diff --git a/allennlp/models/archival.py b/allennlp/models/archival.py index be915c14574..48be2887d6f 100644 --- a/allennlp/models/archival.py +++ b/allennlp/models/archival.py @@ -299,7 +299,26 @@ def extracted_archive(resolved_archive_file, cleanup=True): tempdir = tempfile.mkdtemp() logger.info(f"extracting archive file {resolved_archive_file} to temp dir {tempdir}") with tarfile.open(resolved_archive_file, "r:gz") as archive: - archive.extractall(tempdir) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(archive, tempdir) yield tempdir finally: if tempdir is not None and cleanup: diff --git a/allennlp/tools/archive_surgery.py b/allennlp/tools/archive_surgery.py index 3cba3f57169..e7a0b35644a 100644 --- a/allennlp/tools/archive_surgery.py +++ b/allennlp/tools/archive_surgery.py @@ -67,7 +67,26 @@ def main(): # Extract archive to temp dir tempdir = tempfile.mkdtemp() with tarfile.open(archive_file, "r:gz") as archive: - archive.extractall(tempdir) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(archive, tempdir) atexit.register(lambda: shutil.rmtree(tempdir)) config_path = os.path.join(tempdir, CONFIG_NAME) diff --git a/tests/models/archival_test.py b/tests/models/archival_test.py index 4a40588bc13..7b875a297b0 100644 --- a/tests/models/archival_test.py +++ b/tests/models/archival_test.py @@ -159,7 +159,26 @@ def test_include_in_archive(self): # Assert that the additional targets were archived with tempfile.TemporaryDirectory() as tempdir: with tarfile.open(serialization_dir / "model.tar.gz", "r:gz") as archive: - archive.extractall(tempdir) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(archive, tempdir) assert os.path.isfile(os.path.join(tempdir, "metrics_epoch_0.json")) assert os.path.isfile(os.path.join(tempdir, "metrics_epoch_1.json")) assert not os.path.isfile(os.path.join(tempdir, "metrics.json"))