diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/action/runner/execution/CommandExecutionService.java b/Partner-Core/src/main/java/work/slhaf/partner/core/action/runner/execution/CommandExecutionService.java index 23c1ccee..b6439263 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/core/action/runner/execution/CommandExecutionService.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/action/runner/execution/CommandExecutionService.java @@ -1,9 +1,11 @@ package work.slhaf.partner.core.action.runner.execution; import lombok.Data; +import work.slhaf.partner.core.action.runner.policy.WrappedLaunchSpec; import java.io.BufferedReader; import java.io.InputStreamReader; +import java.io.File; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; @@ -37,15 +39,13 @@ public class CommandExecutionService { return exec(commands.toArray(new String[0])); } - public Result exec(String... commands) { + public Result exec(WrappedLaunchSpec launchSpec) { Result result = new Result(); List output = new ArrayList<>(); List error = new ArrayList<>(); try { - Process process = new ProcessBuilder(commands) - .redirectErrorStream(false) - .start(); + Process process = startProcess(launchSpec); Thread stdoutThread = new Thread(() -> { try (BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()))) { @@ -85,15 +85,17 @@ public class CommandExecutionService { return result; } + public Result exec(String... commands) { + return exec(defaultLaunchSpec(commands)); + } + public CommandSession createSessionTask(List commands) { return createSessionTask(commands.toArray(new String[0])); } - public CommandSession createSessionTask(String... commands) { + public CommandSession createSessionTask(WrappedLaunchSpec launchSpec) { try { - Process process = new ProcessBuilder(commands) - .redirectErrorStream(false) - .start(); + Process process = startProcess(launchSpec); CommandSession session = new CommandSession(); StringBuilder stdoutBuffer = new StringBuilder(); StringBuilder stderrBuffer = new StringBuilder(); @@ -110,6 +112,10 @@ public class CommandExecutionService { } } + public CommandSession createSessionTask(String... commands) { + return createSessionTask(defaultLaunchSpec(commands)); + } + private void readToBuffer(java.io.InputStream inputStream, StringBuilder buffer) { try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) { String line; @@ -125,6 +131,31 @@ public class CommandExecutionService { } } + private Process startProcess(WrappedLaunchSpec launchSpec) throws Exception { + ProcessBuilder processBuilder = new ProcessBuilder(); + List command = new ArrayList<>(); + command.add(launchSpec.getCommand()); + command.addAll(launchSpec.getArgs()); + processBuilder.command(command); + processBuilder.redirectErrorStream(false); + if (launchSpec.getWorkingDirectory() != null && !launchSpec.getWorkingDirectory().isBlank()) { + processBuilder.directory(new File(launchSpec.getWorkingDirectory())); + } + Map environment = processBuilder.environment(); + environment.clear(); + environment.putAll(launchSpec.getEnvironment()); + return processBuilder.start(); + } + + private WrappedLaunchSpec defaultLaunchSpec(String... commands) { + return new WrappedLaunchSpec( + commands[0], + List.of(commands).subList(1, commands.length), + null, + System.getenv() + ); + } + @Data public static class Result { private boolean ok; diff --git a/Partner-Core/src/main/java/work/slhaf/partner/core/action/runner/execution/OriginExecutionService.java b/Partner-Core/src/main/java/work/slhaf/partner/core/action/runner/execution/OriginExecutionService.java index eba315e2..5a841315 100644 --- a/Partner-Core/src/main/java/work/slhaf/partner/core/action/runner/execution/OriginExecutionService.java +++ b/Partner-Core/src/main/java/work/slhaf/partner/core/action/runner/execution/OriginExecutionService.java @@ -6,9 +6,7 @@ import work.slhaf.partner.core.action.runner.policy.ExecutionPolicyRegistry; import work.slhaf.partner.core.action.runner.policy.WrappedLaunchSpec; import java.io.File; -import java.util.ArrayList; import java.util.Arrays; -import java.util.List; import static work.slhaf.partner.core.action.ActionCore.ORIGIN_LOCATION; @@ -22,10 +20,7 @@ public class OriginExecutionService { File file = new File(resolveOriginPath(metaAction)); String[] commands = CommandExecutionService.INSTANCE.buildFileExecutionCommands(metaAction.getLauncher(), metaAction.getParams(), file.getAbsolutePath()); WrappedLaunchSpec wrapped = ExecutionPolicyRegistry.INSTANCE.prepare(Arrays.stream(commands).toList()); - List wrappedCommands = new ArrayList<>(); - wrappedCommands.add(wrapped.getCommand()); - wrappedCommands.addAll(wrapped.getArgs()); - CommandExecutionService.Result execResult = CommandExecutionService.INSTANCE.exec(wrappedCommands); + CommandExecutionService.Result execResult = CommandExecutionService.INSTANCE.exec(wrapped); response.setOk(execResult.isOk()); response.setData(execResult.getTotal()); return response; diff --git a/Partner-Core/src/test/java/work/slhaf/partner/core/action/runner/execution/CommandExecutionServiceTest.java b/Partner-Core/src/test/java/work/slhaf/partner/core/action/runner/execution/CommandExecutionServiceTest.java index 53e7f189..d90bcf06 100644 --- a/Partner-Core/src/test/java/work/slhaf/partner/core/action/runner/execution/CommandExecutionServiceTest.java +++ b/Partner-Core/src/test/java/work/slhaf/partner/core/action/runner/execution/CommandExecutionServiceTest.java @@ -2,6 +2,7 @@ package work.slhaf.partner.core.action.runner.execution; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import work.slhaf.partner.core.action.runner.policy.WrappedLaunchSpec; import java.io.IOException; import java.nio.file.Files; @@ -105,6 +106,32 @@ class CommandExecutionServiceTest { Assertions.assertEquals("oops", session.getStderrBuffer().toString()); } + @Test + void testExecWrappedLaunchSpecAppliesWorkingDirectory(@org.junit.jupiter.api.io.TempDir Path tempDir) { + CommandExecutionService.Result result = service.exec(new WrappedLaunchSpec( + "sh", + List.of("-lc", "pwd"), + tempDir.toString(), + System.getenv() + )); + + Assertions.assertTrue(result.isOk()); + Assertions.assertEquals(tempDir.toString(), result.getTotal()); + } + + @Test + void testExecWrappedLaunchSpecAppliesEnvironmentOverride() { + CommandExecutionService.Result result = service.exec(new WrappedLaunchSpec( + "sh", + List.of("-lc", "printf '%s' \"$PARTNER_TEST_ENV\""), + null, + Map.of("PARTNER_TEST_ENV", "applied") + )); + + Assertions.assertTrue(result.isOk()); + Assertions.assertEquals("applied", result.getTotal()); + } + private void waitForBufferContains(StringBuilder buffer, String expected) throws InterruptedException { long deadline = System.currentTimeMillis() + 2000; while (System.currentTimeMillis() < deadline) { diff --git a/Partner-Core/src/test/java/work/slhaf/partner/core/action/runner/execution/OriginExecutionServiceTest.java b/Partner-Core/src/test/java/work/slhaf/partner/core/action/runner/execution/OriginExecutionServiceTest.java new file mode 100644 index 00000000..b3c73df6 --- /dev/null +++ b/Partner-Core/src/test/java/work/slhaf/partner/core/action/runner/execution/OriginExecutionServiceTest.java @@ -0,0 +1,78 @@ +package work.slhaf.partner.core.action.runner.execution; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import work.slhaf.partner.core.action.entity.MetaAction; +import work.slhaf.partner.core.action.runner.policy.ExecutionPolicy; +import work.slhaf.partner.core.action.runner.policy.ExecutionPolicyRegistry; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import java.util.Set; + +class OriginExecutionServiceTest { + + private static String originalUserHome; + + @BeforeAll + static void prepareTestHome() throws IOException { + originalUserHome = System.getProperty("user.home"); + Path tempHome = Files.createTempDirectory("partner-test-home"); + System.setProperty("user.home", tempHome.toString()); + } + + @AfterAll + static void restoreUserHome() { + if (originalUserHome != null) { + System.setProperty("user.home", originalUserHome); + } + } + + @Test + void testOriginExecutionServiceAppliesExecutionPolicyEnvironment(@TempDir Path tempDir) throws IOException { + Path script = tempDir.resolve("print_env.py"); + Files.writeString(script, "import os\nprint(os.getenv('PARTNER_ORIGIN_TEST', ''), end='')\n"); + + ExecutionPolicy originalPolicy = new ExecutionPolicy( + ExecutionPolicy.Mode.DIRECT, + "direct", + ExecutionPolicy.Network.ENABLE, + true, + Map.of(), + null, + Set.of(), + Set.of() + ); + ExecutionPolicyRegistry.INSTANCE.updatePolicy(new ExecutionPolicy( + ExecutionPolicy.Mode.DIRECT, + "direct", + ExecutionPolicy.Network.ENABLE, + false, + Map.of("PARTNER_ORIGIN_TEST", "origin-applied"), + null, + Set.of(), + Set.of() + )); + + try { + var prepared = ExecutionPolicyRegistry.INSTANCE.prepare(List.of("python3", script.toString())); + Assertions.assertEquals("origin-applied", prepared.getEnvironment().get("PARTNER_ORIGIN_TEST")); + var directExec = CommandExecutionService.INSTANCE.exec(prepared); + Assertions.assertTrue(directExec.isOk()); + Assertions.assertEquals("origin-applied", directExec.getTotal()); + OriginExecutionService service = new OriginExecutionService(); + MetaAction metaAction = new MetaAction("run", false, "python3", MetaAction.Type.ORIGIN, script.toString()); + var response = service.run(metaAction); + Assertions.assertTrue(response.isOk()); + Assertions.assertEquals("origin-applied", response.getData()); + } finally { + ExecutionPolicyRegistry.INSTANCE.updatePolicy(originalPolicy); + } + } +}