@@ -67,10 +67,14 @@ def from_package(package):
67
67
68
68
69
69
@contextlib .contextmanager
70
- def _tempfile (reader , suffix = '' ,
71
- # gh-93353: Keep a reference to call os.remove() in late Python
72
- # finalization.
73
- * , _os_remove = os .remove ):
70
+ def _tempfile (
71
+ reader ,
72
+ suffix = '' ,
73
+ # gh-93353: Keep a reference to call os.remove() in late Python
74
+ # finalization.
75
+ * ,
76
+ _os_remove = os .remove ,
77
+ ):
74
78
# Not using tempfile.NamedTemporaryFile as it leads to deeper 'try'
75
79
# blocks due to the need to close the temporary file to work on Windows
76
80
# properly.
@@ -89,13 +93,30 @@ def _tempfile(reader, suffix='',
89
93
pass
90
94
91
95
96
+ def _temp_file (path ):
97
+ return _tempfile (path .read_bytes , suffix = path .name )
98
+
99
+
100
+ def _is_present_dir (path : Traversable ) -> bool :
101
+ """
102
+ Some Traversables implement ``is_dir()`` to raise an
103
+ exception (i.e. ``FileNotFoundError``) when the
104
+ directory doesn't exist. This function wraps that call
105
+ to always return a boolean and only return True
106
+ if there's a dir and it exists.
107
+ """
108
+ with contextlib .suppress (FileNotFoundError ):
109
+ return path .is_dir ()
110
+ return False
111
+
112
+
92
113
@functools .singledispatch
93
114
def as_file (path ):
94
115
"""
95
116
Given a Traversable object, return that object as a
96
117
path on the local file system in a context manager.
97
118
"""
98
- return _tempfile (path . read_bytes , suffix = path . name )
119
+ return _temp_dir (path ) if _is_present_dir ( path ) else _temp_file ( path )
99
120
100
121
101
122
@as_file .register (pathlib .Path )
@@ -105,3 +126,34 @@ def _(path):
105
126
Degenerate behavior for pathlib.Path objects.
106
127
"""
107
128
yield path
129
+
130
+
131
+ @contextlib .contextmanager
132
+ def _temp_path (dir : tempfile .TemporaryDirectory ):
133
+ """
134
+ Wrap tempfile.TemporyDirectory to return a pathlib object.
135
+ """
136
+ with dir as result :
137
+ yield pathlib .Path (result )
138
+
139
+
140
+ @contextlib .contextmanager
141
+ def _temp_dir (path ):
142
+ """
143
+ Given a traversable dir, recursively replicate the whole tree
144
+ to the file system in a context manager.
145
+ """
146
+ assert path .is_dir ()
147
+ with _temp_path (tempfile .TemporaryDirectory ()) as temp_dir :
148
+ yield _write_contents (temp_dir , path )
149
+
150
+
151
+ def _write_contents (target , source ):
152
+ child = target .joinpath (source .name )
153
+ if source .is_dir ():
154
+ child .mkdir ()
155
+ for item in source .iterdir ():
156
+ _write_contents (child , item )
157
+ else :
158
+ child .open ('wb' ).write (source .read_bytes ())
159
+ return child
0 commit comments