85 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			85 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|  | import inspect | ||
|  | import typing as t | ||
|  | from functools import WRAPPER_ASSIGNMENTS | ||
|  | from functools import wraps | ||
|  | 
 | ||
|  | from .utils import _PassArg | ||
|  | from .utils import pass_eval_context | ||
|  | 
 | ||
# Generic type variable used in the annotations of auto_await, auto_aiter
# and auto_to_list below.
V = t.TypeVar("V")
|  | 
 | ||
|  | 
 | ||
def async_variant(normal_func):  # type: ignore
    """Pair *normal_func* with an async implementation of the same operation.

    Returns a decorator meant to be applied to the async implementation.
    The decorated result is a single callable that reads the environment's
    ``is_async`` flag on every call and forwards to either the async or
    the sync implementation.

    If *normal_func* carries no ``_PassArg`` marker, the returned callable
    is wrapped with ``pass_eval_context`` so it can reach the environment.
    That extra leading argument is removed before either implementation
    is called. The returned callable has ``jinja_async_variant = True``.
    """

    def decorator(async_func):  # type: ignore
        pass_arg = _PassArg.from_obj(normal_func)
        # No marker on the sync function: pass_eval_context is added below,
        # so the injected eval context must be stripped before forwarding.
        drop_leading_arg = pass_arg is None
        # With pass_environment the first argument is the environment itself;
        # otherwise it is an object exposing ``.environment``.
        first_arg_is_env = pass_arg is _PassArg.environment

        def is_async(args: t.Any) -> bool:
            head = args[0]
            env = head if first_arg_is_env else head.environment
            return t.cast(bool, env.is_async)

        # Take the doc and annotations from the sync function, but the
        # name from the async function. Pallets-Sphinx-Themes
        # build_function_directive expects __wrapped__ to point to the
        # sync function.
        async_func_attrs = ("__module__", "__name__", "__qualname__")
        normal_func_attrs = tuple(
            attr for attr in WRAPPER_ASSIGNMENTS if attr not in async_func_attrs
        )

        # The outer wraps is applied last, so it leaves __wrapped__ pointing
        # at normal_func.
        @wraps(normal_func, assigned=normal_func_attrs)
        @wraps(async_func, assigned=async_func_attrs, updated=())
        def wrapper(*args, **kwargs):  # type: ignore
            # Pick the target before slicing: the flag lives on args[0].
            target = async_func if is_async(args) else normal_func
            forwarded = args[1:] if drop_leading_arg else args
            return target(*forwarded, **kwargs)

        if drop_leading_arg:
            wrapper = pass_eval_context(wrapper)

        wrapper.jinja_async_variant = True
        return wrapper

    return decorator
|  | 
 | ||
|  | 
 | ||
# Exact builtin types that are not awaitable. auto_await tests membership by
# ``type(value)`` so it can skip the comparatively costly isawaitable check.
_common_primitives = {int, float, bool, str, list, dict, tuple, type(None)}


async def auto_await(value: t.Union[t.Awaitable["V"], "V"]) -> "V":
    """Return *value*, awaiting it first if it is awaitable.

    Values whose exact type is one of the common builtin primitives are
    returned immediately, without calling ``inspect.isawaitable``.
    """
    # Short-circuit order matters: the cheap exact-type check runs first.
    if type(value) not in _common_primitives and inspect.isawaitable(value):
        return await t.cast("t.Awaitable[V]", value)
    return t.cast("V", value)
|  | 
 | ||
|  | 
 | ||
async def auto_aiter(
    iterable: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
) -> "t.AsyncIterator[V]":
    """Yield every item of *iterable*, whether it is sync or async.

    An object with an ``__aiter__`` attribute is consumed with
    ``async for``; anything else is consumed with a plain ``for`` loop.
    """
    if not hasattr(iterable, "__aiter__"):
        for element in t.cast("t.Iterable[V]", iterable):
            yield element
        return

    async for element in t.cast("t.AsyncIterable[V]", iterable):
        yield element
|  | 
 | ||
|  | 
 | ||
async def auto_to_list(
    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
) -> t.List["V"]:
    """Collect all items of a sync or async iterable into a new list."""
    items: t.List["V"] = []
    async for item in auto_aiter(value):
        items.append(item)
    return items