mirror of
				https://github.com/yt-dlp/yt-dlp.git
				synced 2025-10-30 22:25:19 +00:00 
			
		
		
		
	[utils] traverse_obj: Support xml.etree.ElementTree.Element (#8911)
				
					
				
			Authored by: Grub4K
This commit is contained in:
		| @@ -2340,6 +2340,58 @@ Line 1 | ||||
|         self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'], | ||||
|                          msg='function on a `re.Match` should give group name as well') | ||||
| 
 | ||||
|         # Test xml.etree.ElementTree.Element as input obj | ||||
|         etree = xml.etree.ElementTree.fromstring('''<?xml version="1.0"?> | ||||
|         <data> | ||||
|             <country name="Liechtenstein"> | ||||
|                 <rank>1</rank> | ||||
|                 <year>2008</year> | ||||
|                 <gdppc>141100</gdppc> | ||||
|                 <neighbor name="Austria" direction="E"/> | ||||
|                 <neighbor name="Switzerland" direction="W"/> | ||||
|             </country> | ||||
|             <country name="Singapore"> | ||||
|                 <rank>4</rank> | ||||
|                 <year>2011</year> | ||||
|                 <gdppc>59900</gdppc> | ||||
|                 <neighbor name="Malaysia" direction="N"/> | ||||
|             </country> | ||||
|             <country name="Panama"> | ||||
|                 <rank>68</rank> | ||||
|                 <year>2011</year> | ||||
|                 <gdppc>13600</gdppc> | ||||
|                 <neighbor name="Costa Rica" direction="W"/> | ||||
|                 <neighbor name="Colombia" direction="E"/> | ||||
|             </country> | ||||
|         </data>''') | ||||
|         self.assertEqual(traverse_obj(etree, ''), etree, | ||||
|                          msg='empty str key should return the element itself') | ||||
|         self.assertEqual(traverse_obj(etree, 'country'), list(etree), | ||||
|                          msg='str key should lead all children with that tag name') | ||||
|         self.assertEqual(traverse_obj(etree, ...), list(etree), | ||||
|                          msg='`...` as key should return all children') | ||||
|         self.assertEqual(traverse_obj(etree, lambda _, x: x[0].text == '4'), [etree[1]], | ||||
|                          msg='function as key should get element as value') | ||||
|         self.assertEqual(traverse_obj(etree, lambda i, _: i == 1), [etree[1]], | ||||
|                          msg='function as key should get index as key') | ||||
|         self.assertEqual(traverse_obj(etree, 0), etree[0], | ||||
|                          msg='int key should return the nth child') | ||||
|         self.assertEqual(traverse_obj(etree, './/neighbor/@name'), | ||||
|                          ['Austria', 'Switzerland', 'Malaysia', 'Costa Rica', 'Colombia'], | ||||
|                          msg='`@<attribute>` at end of path should give that attribute') | ||||
|         self.assertEqual(traverse_obj(etree, '//neighbor/@fail'), [None, None, None, None, None], | ||||
|                          msg='`@<nonexistant>` at end of path should give `None`') | ||||
|         self.assertEqual(traverse_obj(etree, ('//neighbor/@', 2)), {'name': 'Malaysia', 'direction': 'N'}, | ||||
|                          msg='`@` should give the full attribute dict') | ||||
|         self.assertEqual(traverse_obj(etree, '//year/text()'), ['2008', '2011', '2011'], | ||||
|                          msg='`text()` at end of path should give the inner text') | ||||
|         self.assertEqual(traverse_obj(etree, '//*[@direction]/@direction'), ['E', 'W', 'N', 'W', 'E'], | ||||
|                          msg='full python xpath features should be supported') | ||||
|         self.assertEqual(traverse_obj(etree, (0, '@name')), 'Liechtenstein', | ||||
|                          msg='special transformations should act on current element') | ||||
|         self.assertEqual(traverse_obj(etree, ('country', 0, ..., 'text()', {int_or_none})), [1, 2008, 141100], | ||||
|                          msg='special transformations should act on current element') | ||||
| 
 | ||||
|     def test_http_header_dict(self): | ||||
|         headers = HTTPHeaderDict() | ||||
|         headers['ytdl-test'] = b'0' | ||||
|   | ||||
| @@ -3,6 +3,7 @@ import contextlib | ||||
| import inspect | ||||
| import itertools | ||||
| import re | ||||
| import xml.etree.ElementTree | ||||
| 
 | ||||
| from ._utils import ( | ||||
|     IDENTITY, | ||||
| @@ -118,7 +119,7 @@ def traverse_obj( | ||||
|             branching = True | ||||
|             if isinstance(obj, collections.abc.Mapping): | ||||
|                 result = obj.values() | ||||
|             elif is_iterable_like(obj): | ||||
|             elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element): | ||||
|                 result = obj | ||||
|             elif isinstance(obj, re.Match): | ||||
|                 result = obj.groups() | ||||
| @@ -132,7 +133,7 @@ def traverse_obj( | ||||
|             branching = True | ||||
|             if isinstance(obj, collections.abc.Mapping): | ||||
|                 iter_obj = obj.items() | ||||
|             elif is_iterable_like(obj): | ||||
|             elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element): | ||||
|                 iter_obj = enumerate(obj) | ||||
|             elif isinstance(obj, re.Match): | ||||
|                 iter_obj = itertools.chain( | ||||
| @@ -168,7 +169,7 @@ def traverse_obj( | ||||
|                 result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) | ||||
| 
 | ||||
|         elif isinstance(key, (int, slice)): | ||||
|             if is_iterable_like(obj, collections.abc.Sequence): | ||||
|             if is_iterable_like(obj, (collections.abc.Sequence, xml.etree.ElementTree.Element)): | ||||
|                 branching = isinstance(key, slice) | ||||
|                 with contextlib.suppress(IndexError): | ||||
|                     result = obj[key] | ||||
| @@ -176,6 +177,34 @@ def traverse_obj( | ||||
|                 with contextlib.suppress(IndexError): | ||||
|                     result = str(obj)[key] | ||||
| 
 | ||||
|         elif isinstance(obj, xml.etree.ElementTree.Element) and isinstance(key, str): | ||||
|             xpath, _, special = key.rpartition('/') | ||||
|             if not special.startswith('@') and special != 'text()': | ||||
|                 xpath = key | ||||
|                 special = None | ||||
| 
 | ||||
|             # Allow abbreviations of relative paths, absolute paths error | ||||
|             if xpath.startswith('/'): | ||||
|                 xpath = f'.{xpath}' | ||||
|             elif xpath and not xpath.startswith('./'): | ||||
|                 xpath = f'./{xpath}' | ||||
| 
 | ||||
|             def apply_specials(element): | ||||
|                 if special is None: | ||||
|                     return element | ||||
|                 if special == '@': | ||||
|                     return element.attrib | ||||
|                 if special.startswith('@'): | ||||
|                     return try_call(element.attrib.get, args=(special[1:],)) | ||||
|                 if special == 'text()': | ||||
|                     return element.text | ||||
|                 assert False, f'apply_specials is missing case for {special!r}' | ||||
| 
 | ||||
|             if xpath: | ||||
|                 result = list(map(apply_specials, obj.iterfind(xpath))) | ||||
|             else: | ||||
|                 result = apply_specials(obj) | ||||
| 
 | ||||
|         return branching, result if branching else (result,) | ||||
| 
 | ||||
|     def lazy_last(iterable): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Simon Sawicki
					Simon Sawicki