diff --git a/cyaron/tests/compare_test.py b/cyaron/tests/compare_test.py index bc0a830..c0abfc8 100644 --- a/cyaron/tests/compare_test.py +++ b/cyaron/tests/compare_test.py @@ -4,7 +4,7 @@ import shutil import tempfile import subprocess -from cyaron import IO, Compare, log +from cyaron import IO, Compare, log, escape_path from cyaron.output_capture import captured_output from cyaron.graders.mismatch import * from cyaron.compare import CompareMismatch @@ -83,13 +83,13 @@ def test_fulltext_program(self): try: with captured_output() as (out, err): - Compare.program(f"{sys.executable} correct.py", - f"{sys.executable} incorrect.py", + Compare.program(f"{escape_path(sys.executable)} correct.py", + f"{escape_path(sys.executable)} incorrect.py", std=io, input=io, grader="FullText") except CompareMismatch as e: - self.assertEqual(e.name, f'{sys.executable} incorrect.py') + self.assertEqual(e.name, f'{escape_path(sys.executable)} incorrect.py') e = e.mismatch self.assertEqual(e.content, '2\n') self.assertEqual(e.std, '1\n') @@ -105,7 +105,7 @@ def test_fulltext_program(self): self.assertTrue(False) result = out.getvalue().strip() - correct_out = f'{sys.executable} correct.py: Correct \n{sys.executable} incorrect.py: !!!INCORRECT!!! Hash mismatch: read 53c234e5e8472b6ac51c1ae1cab3fe06fad053beb8ebfd8977b010655bfdd3c3, expected 4355a46b19d348dc2f57c046f8ef63d4538ebb936000f3c9ee954a27460dd865' + correct_out = f'{escape_path(sys.executable)} correct.py: Correct \n{escape_path(sys.executable)} incorrect.py: !!!INCORRECT!!! Hash mismatch: read 53c234e5e8472b6ac51c1ae1cab3fe06fad053beb8ebfd8977b010655bfdd3c3, expected 4355a46b19d348dc2f57c046f8ef63d4538ebb936000f3c9ee954a27460dd865' self.assertEqual(result, correct_out) def test_file_input(self): @@ -122,13 +122,13 @@ def test_file_input(self): io.input_writeln("233") with captured_output() as (out, err): - Compare.program(f"{sys.executable} correct.py", - std_program=f"{sys.executable} std.py", + Compare.program(f"{escape_path(sys.executable)} correct.py", + std_program=f"{escape_path(sys.executable)} std.py", input=io, grader="NOIPStyle") result = out.getvalue().strip() - correct_out = f'{sys.executable} correct.py: Correct' + correct_out = f'{escape_path(sys.executable)} correct.py: Correct' self.assertEqual(result, correct_out) def test_concurrent(self): diff --git a/cyaron/tests/io_test.py b/cyaron/tests/io_test.py index 02b5a98..7e172d7 100644 --- a/cyaron/tests/io_test.py +++ b/cyaron/tests/io_test.py @@ -5,7 +5,7 @@ import shutil import tempfile import subprocess -from cyaron import IO +from cyaron import IO, escape_path from cyaron.output_capture import captured_output @@ -92,7 +92,7 @@ def test_output_gen_time_limit_exceeded(self): abs_input_filename: str = os.path.abspath(input_filename) with self.assertRaises(subprocess.TimeoutExpired): test.input_writeln(abs_input_filename) - test.output_gen(f'"{sys.executable}" long_time.py', + test.output_gen(f'{escape_path(sys.executable)} long_time.py', time_limit=TIMEOUT) time.sleep(WAIT_TIME) try: @@ -108,7 +108,7 @@ def test_output_gen_time_limit_not_exceeded(self): "print(1)") with IO("test_gen.in", "test_gen.out") as test: - test.output_gen(f'"{sys.executable}" short_time.py', + test.output_gen(f'{escape_path(sys.executable)} short_time.py', time_limit=0.5) with open("test_gen.out", encoding="utf-8") as f: output = f.read() diff --git a/cyaron/utils.py b/cyaron/utils.py index 4ef5ad8..38d4c2e 100644 --- a/cyaron/utils.py +++ b/cyaron/utils.py @@ -1,11 +1,13 @@ """Some utility functions.""" +import os +import shlex import sys import random from typing import cast, Any, Dict, Iterable, Tuple, Union __all__ = [ "ati", "list_like", "int_like", "strtolines", "make_unicode", - "unpack_kwargs", "process_args" + "unpack_kwargs", "process_args", "escape_path" ] @@ -79,3 +81,10 @@ def process_args(): for s in sys.argv: if s.startswith("--randseed="): random.seed(s.split("=")[1]) + +def escape_path(path: str) -> str: + """Escape the path.""" + if os.name == 'nt': + return '"' + path.replace('\\', '/') + '"' + else: + return shlex.quote(path)