Skip to content

Commit 60212b1

Browse files
committed
feat(backend/sdk): Add download_to_workspace option to dsl.importer
Signed-off-by: VaniHaripriya <[email protected]>
1 parent 4fdd6ac commit 60212b1

File tree

17 files changed

+881
-25
lines changed

17 files changed

+881
-25
lines changed

api/v2alpha1/go/pipelinespec/pipeline_spec.pb.go

Lines changed: 16 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

api/v2alpha1/pipeline_spec.proto

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,6 +873,9 @@ message PipelineDeploymentConfig {
873873

874874
// Whether or not to import an artifact regardless of whether it has been imported before.
875875
bool reimport = 5;
876+
877+
// If true, download artifact into the pipeline workspace.
878+
bool download_to_workspace = 7;
876879
}
877880

878881
// ResolverSpec resolves artifacts from historical metadata and returns them

backend/src/v2/compiler/argocompiler/importer.go

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,32 @@ func (c *workflowCompiler) addImporterTemplate() string {
9191
if value, ok := os.LookupEnv(PublishLogsEnvVar); ok {
9292
args = append(args, "--publish_logs", value)
9393
}
94+
// Add workspace volume only if the workflow defines a workspace PVC
95+
hasWorkspacePVC := false
96+
for _, pvc := range c.wf.Spec.VolumeClaimTemplates {
97+
if pvc.Name == workspaceVolumeName {
98+
hasWorkspacePVC = true
99+
break
100+
}
101+
}
102+
103+
var volumeMounts []k8score.VolumeMount
104+
var volumes []k8score.Volume
105+
if hasWorkspacePVC {
106+
volumeMounts = append(volumeMounts, k8score.VolumeMount{
107+
Name: workspaceVolumeName,
108+
MountPath: component.WorkspaceMountPath,
109+
})
110+
volumes = append(volumes, k8score.Volume{
111+
Name: workspaceVolumeName,
112+
VolumeSource: k8score.VolumeSource{
113+
PersistentVolumeClaim: &k8score.PersistentVolumeClaimVolumeSource{
114+
ClaimName: fmt.Sprintf("{{workflow.name}}-%s", workspaceVolumeName),
115+
},
116+
},
117+
})
118+
}
119+
94120
importerTemplate := &wfapi.Template{
95121
Name: name,
96122
Inputs: wfapi.Inputs{
@@ -102,13 +128,15 @@ func (c *workflowCompiler) addImporterTemplate() string {
102128
},
103129
},
104130
Container: &k8score.Container{
105-
Image: c.launcherImage,
106-
Command: c.launcherCommand,
107-
Args: args,
108-
EnvFrom: []k8score.EnvFromSource{metadataEnvFrom},
109-
Env: commonEnvs,
110-
Resources: driverResources,
131+
Image: c.launcherImage,
132+
Command: c.launcherCommand,
133+
Args: args,
134+
EnvFrom: []k8score.EnvFromSource{metadataEnvFrom},
135+
Env: commonEnvs,
136+
Resources: driverResources,
137+
VolumeMounts: volumeMounts,
111138
},
139+
Volumes: volumes,
112140
}
113141

114142
// If the apiserver is TLS-enabled, add the custom CA bundle to the importer template.

backend/src/v2/component/importer_launcher.go

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7+
"os"
8+
"path/filepath"
79
"strings"
810

911
"github.com/kubeflow/pipelines/backend/src/common/util"
1012

13+
"github.com/kubeflow/pipelines/backend/src/v2/config"
1114
"github.com/kubeflow/pipelines/backend/src/v2/objectstore"
1215

1316
pb "github.com/kubeflow/pipelines/third_party/ml-metadata/go/ml_metadata"
@@ -126,12 +129,17 @@ func (l *ImportLauncher) Execute(ctx context.Context) (err error) {
126129
return err
127130
}
128131
ecfg := &metadata.ExecutionConfig{
129-
TaskName: l.task.GetTaskInfo().GetName(),
130-
PodName: l.launcherV2Options.PodName,
131-
PodUID: l.launcherV2Options.PodUID,
132-
Namespace: l.launcherV2Options.Namespace,
133-
ExecutionType: metadata.ImporterExecutionTypeName,
134-
ParentDagID: l.importerLauncherOptions.ParentDagID,
132+
TaskName: l.task.GetTaskInfo().GetName(),
133+
PodName: l.launcherV2Options.PodName,
134+
PodUID: l.launcherV2Options.PodUID,
135+
Namespace: l.launcherV2Options.Namespace,
136+
ExecutionType: func() metadata.ExecutionType {
137+
if l.importer.GetDownloadToWorkspace() {
138+
return metadata.ImporterWorkspaceExecutionTypeName
139+
}
140+
return metadata.ImporterExecutionTypeName
141+
}(),
142+
ParentDagID: l.importerLauncherOptions.ParentDagID,
135143
}
136144
createdExecution, err := l.metadataClient.CreateExecution(ctx, pipeline, ecfg)
137145
if err != nil {
@@ -253,15 +261,17 @@ func (l *ImportLauncher) ImportSpecToMLMDArtifact(ctx context.Context) (artifact
253261
}
254262

255263
if strings.HasPrefix(artifactUri, "oci://") {
264+
// OCI artifacts are not supported when workspace is used
265+
if l.importer.GetDownloadToWorkspace() {
266+
return nil, fmt.Errorf("importer workspace download does not support OCI registries")
267+
}
256268
artifactType, err := metadata.SchemaToArtifactType(schema)
257269
if err != nil {
258270
return nil, fmt.Errorf("converting schema to artifact type failed: %w", err)
259271
}
260-
261272
if *artifactType.Name != "system.Model" {
262273
return nil, fmt.Errorf("the %s artifact type does not support OCI registries", *artifactType.Name)
263274
}
264-
265275
return artifact, nil
266276
}
267277

@@ -283,6 +293,40 @@ func (l *ImportLauncher) ImportSpecToMLMDArtifact(ctx context.Context) (artifact
283293
}
284294
storeSessionInfoStr := string(storeSessionInfoJSON)
285295
artifact.CustomProperties["store_session_info"] = metadata.StringValue(storeSessionInfoStr)
296+
297+
// Download the artifact into the workspace
298+
if l.importer.GetDownloadToWorkspace() {
299+
bucketConfig, err := objectstore.ParseBucketConfigForArtifactURI(artifactUri)
300+
if err != nil {
301+
return nil, fmt.Errorf("failed to parse bucket config while downloading artifact into workspace with uri %q: %w", artifactUri, err)
302+
}
303+
// Resolve and attach session info from kfp-launcher config for the artifact provider
304+
if cfg, cfgErr := config.FromConfigMap(ctx, l.k8sClient, l.launcherV2Options.Namespace); cfgErr != nil {
305+
glog.Warningf("failed to load launcher config for workspace download: %v", cfgErr)
306+
} else if cfg != nil {
307+
if sess, sessErr := cfg.GetStoreSessionInfo(artifactUri); sessErr != nil {
308+
glog.Warningf("failed to resolve store session info for %q: %v", artifactUri, sessErr)
309+
} else {
310+
bucketConfig.SessionInfo = &sess
311+
}
312+
}
313+
blobKey, err := bucketConfig.KeyFromURI(artifactUri)
314+
if err != nil {
315+
return nil, fmt.Errorf("failed to derive blob key from uri %q while downloading artifact into workspace: %w", artifactUri, err)
316+
}
317+
workspaceRoot := filepath.Join(WorkspaceMountPath, ".artifacts")
318+
if err := os.MkdirAll(workspaceRoot, 0755); err != nil {
319+
return nil, fmt.Errorf("failed to create workspace directory %q: %w", workspaceRoot, err)
320+
}
321+
bucket, err := objectstore.OpenBucket(ctx, l.k8sClient, l.launcherV2Options.Namespace, bucketConfig)
322+
if err != nil {
323+
return nil, fmt.Errorf("failed to open bucket for uri %q: %w", artifactUri, err)
324+
}
325+
defer bucket.Close()
326+
if err := objectstore.DownloadBlob(ctx, bucket, workspaceRoot, blobKey); err != nil {
327+
return nil, fmt.Errorf("failed to download artifact to workspace: %w", err)
328+
}
329+
}
286330
return artifact, nil
287331
}
288332

backend/src/v2/component/launcher_v2.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,12 @@ func downloadArtifacts(ctx context.Context, executorInput *pipelinespec.Executor
713713
for _, artifact := range artifactList.Artifacts {
714714
// Iterating through the artifact list allows for collected artifacts to be properly consumed.
715715
inputArtifact := artifact
716+
// Skip downloading if the artifact is flagged as already present in the workspace
717+
if inputArtifact.GetMetadata() != nil {
718+
if v, ok := inputArtifact.GetMetadata().GetFields()["_kfp_workspace"]; ok && v.GetBoolValue() {
719+
continue
720+
}
721+
}
716722
localPath, err := LocalPathForURI(inputArtifact.Uri)
717723
if err != nil {
718724
glog.Warningf("Input Artifact %q does not have a recognized storage URI %q. Skipping downloading to local path.", name, inputArtifact.Uri)
@@ -856,6 +862,22 @@ func getPlaceholders(executorInput *pipelinespec.ExecutorInput) (placeholders ma
856862
key := fmt.Sprintf(`{{$.inputs.artifacts['%s'].uri}}`, name)
857863
placeholders[key] = inputArtifact.Uri
858864

865+
// If the artifact is marked as already in the workspace, map the workspace path.
866+
if inputArtifact.GetMetadata() != nil {
867+
if v, ok := inputArtifact.GetMetadata().GetFields()["_kfp_workspace"]; ok && v.GetBoolValue() {
868+
bucketConfig, err := objectstore.ParseBucketConfigForArtifactURI(inputArtifact.Uri)
869+
if err == nil {
870+
blobKey, err := bucketConfig.KeyFromURI(inputArtifact.Uri)
871+
if err == nil {
872+
localPath := filepath.Join(WorkspaceMountPath, ".artifacts", blobKey)
873+
key = fmt.Sprintf(`{{$.inputs.artifacts['%s'].path}}`, name)
874+
placeholders[key] = localPath
875+
continue
876+
}
877+
}
878+
}
879+
}
880+
859881
localPath, err := LocalPathForURI(inputArtifact.Uri)
860882
if err != nil {
861883
// Input Artifact does not have a recognized storage URI

backend/src/v2/component/launcher_v2_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"errors"
2121
"io"
2222
"os"
23+
"path/filepath"
2324
"testing"
2425

2526
"github.com/kubeflow/pipelines/backend/src/v2/cacheutils"
@@ -176,6 +177,29 @@ func Test_executeV2_publishLogs(t *testing.T) {
176177
}
177178
}
178179

180+
func Test_getPlaceholders_WorkspaceArtifactPath(t *testing.T) {
181+
execIn := &pipelinespec.ExecutorInput{
182+
Inputs: &pipelinespec.ExecutorInput_Inputs{
183+
Artifacts: map[string]*pipelinespec.ArtifactList{
184+
"data": {
185+
Artifacts: []*pipelinespec.RuntimeArtifact{
186+
{Uri: "minio://mlpipeline/sample/sample.txt", Metadata: &structpb.Struct{Fields: map[string]*structpb.Value{"_kfp_workspace": structpb.NewBoolValue(true)}}},
187+
},
188+
},
189+
},
190+
},
191+
}
192+
ph, err := getPlaceholders(execIn)
193+
if err != nil {
194+
t.Fatalf("getPlaceholders error: %v", err)
195+
}
196+
got := ph["{{$.inputs.artifacts['data'].path}}"]
197+
want := filepath.Join(WorkspaceMountPath, ".artifacts", "sample", "sample.txt")
198+
if got != want {
199+
t.Fatalf("placeholder path mismatch: got=%q want=%q", got, want)
200+
}
201+
}
202+
179203
func Test_executorInput_compileCmdAndArgs(t *testing.T) {
180204
executorInputJSON := `{
181205
"inputs": {

backend/src/v2/driver/driver_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,6 +1189,68 @@ func TestWorkspaceMount_PassthroughVolumes_ApplyAndCapture(t *testing.T) {
11891189
}
11901190
}
11911191

1192+
func TestWorkspaceMount_TriggeredByArtifactMetadata(t *testing.T) {
1193+
proxy.InitializeConfigWithEmptyForTests()
1194+
containerSpec := &pipelinespec.PipelineDeploymentConfig_PipelineContainerSpec{Image: "python:3.9"}
1195+
componentSpec := &pipelinespec.ComponentSpec{
1196+
TaskConfigPassthroughs: []*pipelinespec.TaskConfigPassthrough{
1197+
{
1198+
Field: pipelinespec.TaskConfigPassthroughType_KUBERNETES_VOLUMES,
1199+
ApplyToTask: true,
1200+
},
1201+
},
1202+
}
1203+
1204+
// Build an ExecutorInput that does NOT reference workspace path in params,
1205+
// but contains an input artifact marked as already in workspace.
1206+
execInput := &pipelinespec.ExecutorInput{
1207+
Inputs: &pipelinespec.ExecutorInput_Inputs{
1208+
Artifacts: map[string]*pipelinespec.ArtifactList{
1209+
"data": {
1210+
Artifacts: []*pipelinespec.RuntimeArtifact{
1211+
{
1212+
Uri: "minio://mlpipeline/sample/sample.txt",
1213+
Metadata: &structpb.Struct{Fields: map[string]*structpb.Value{
1214+
"_kfp_workspace": structpb.NewBoolValue(true),
1215+
}},
1216+
},
1217+
},
1218+
},
1219+
},
1220+
},
1221+
}
1222+
1223+
taskCfg := &TaskConfig{}
1224+
podSpec, err := initPodSpecPatch(
1225+
containerSpec, componentSpec, execInput,
1226+
27, "test", "run", "my-run-name", "1", "false", "false", taskCfg, false, false, "",
1227+
)
1228+
assert.Nil(t, err)
1229+
1230+
// Expect workspace volume mounted
1231+
if assert.Len(t, podSpec.Volumes, 1) {
1232+
assert.Equal(t, "kfp-workspace", podSpec.Volumes[0].Name)
1233+
if assert.NotNil(t, podSpec.Volumes[0].PersistentVolumeClaim) {
1234+
assert.Equal(t, "my-run-name-kfp-workspace", podSpec.Volumes[0].PersistentVolumeClaim.ClaimName)
1235+
}
1236+
}
1237+
if assert.Len(t, podSpec.Containers, 1) {
1238+
if assert.Len(t, podSpec.Containers[0].VolumeMounts, 1) {
1239+
assert.Equal(t, "kfp-workspace", podSpec.Containers[0].VolumeMounts[0].Name)
1240+
assert.Equal(t, "/kfp-workspace", podSpec.Containers[0].VolumeMounts[0].MountPath)
1241+
}
1242+
}
1243+
1244+
// Also captured to TaskConfig
1245+
if assert.Len(t, taskCfg.Volumes, 1) {
1246+
assert.Equal(t, "kfp-workspace", taskCfg.Volumes[0].Name)
1247+
}
1248+
if assert.Len(t, taskCfg.VolumeMounts, 1) {
1249+
assert.Equal(t, "kfp-workspace", taskCfg.VolumeMounts[0].Name)
1250+
assert.Equal(t, "/kfp-workspace", taskCfg.VolumeMounts[0].MountPath)
1251+
}
1252+
}
1253+
11921254
func Test_initPodSpecPatch_TaskConfig_Env_Passthrough_CaptureOnly(t *testing.T) {
11931255
proxy.InitializeConfigWithEmptyForTests()
11941256
containerSpec := &pipelinespec.PipelineDeploymentConfig_PipelineContainerSpec{

backend/src/v2/driver/resolve.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,17 @@ import (
3333

3434
var ErrResolvedParameterNull = errors.New("the resolved input parameter is null")
3535

36+
// setWorkspaceFlag sets the _kfp_workspace metadata flag on the provided
37+
// runtime artifact when the producing execution is an ImporterWorkspace.
38+
func setWorkspaceFlag(execution *metadata.Execution, runtimeArtifact *pipelinespec.RuntimeArtifact) {
39+
if execution.GetExecution().GetType() == string(metadata.ImporterWorkspaceExecutionTypeName) {
40+
if runtimeArtifact.Metadata == nil {
41+
runtimeArtifact.Metadata = &structpb.Struct{Fields: map[string]*structpb.Value{}}
42+
}
43+
runtimeArtifact.Metadata.Fields["_kfp_workspace"] = structpb.NewBoolValue(true)
44+
}
45+
}
46+
3647
// resolveUpstreamOutputsConfig is just a config struct used to store the input
3748
// parameters of the resolveUpstreamParameters and resolveUpstreamArtifacts
3849
// functions.
@@ -763,6 +774,8 @@ func resolveUpstreamArtifacts(cfg resolveUpstreamOutputsConfig) (*pipelinespec.A
763774
if err != nil {
764775
cfg.err(err)
765776
}
777+
// If produced by a workspace importer, set _kfp_workspace=true.
778+
setWorkspaceFlag(currentTask, runtimeArtifact)
766779
// Base case
767780
return &pipelinespec.ArtifactList{
768781
Artifacts: []*pipelinespec.RuntimeArtifact{runtimeArtifact},
@@ -980,6 +993,8 @@ func collectContainerOutput(
980993
if err != nil {
981994
return nil, nil, cfg.err(err)
982995
}
996+
// If produced by a workspace importer, set _kfp_workspace=true.
997+
setWorkspaceFlag(currentTask, artifact)
983998
glog.V(4).Infof("runtimeArtifact: %v", artifact)
984999
} else {
9851000
_, outputParameters, err := currentTask.GetParameters()

0 commit comments

Comments
 (0)