Skip to content

Commit 4af58d4

Browse files
authored
Configure service name env variables on API server (#12463)
* Allow cache endpoint configurable with pod namespace. Signed-off-by: alyssacgoins <[email protected]> * update compiler tests. Signed-off-by: alyssacgoins <[email protected]> --------- Signed-off-by: alyssacgoins <[email protected]>
1 parent 0302125 commit 4af58d4

File tree

143 files changed

+1143
-70
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

143 files changed

+1143
-70
lines changed

backend/src/apiserver/common/config.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ const (
3636
CaBundleSecretName string = "CABUNDLE_SECRET_NAME"
3737
RequireNamespaceForPipelines string = "REQUIRE_NAMESPACE_FOR_PIPELINES"
3838
CompiledPipelineSpecPatch string = "COMPILED_PIPELINE_SPEC_PATCH"
39+
MLPipelineServiceName string = "ML_PIPELINE_SERVICE_NAME"
40+
MetadataServiceName string = "METADATA_SERVICE_NAME"
3941
)
4042

4143
func IsPipelineVersionUpdatedByDefault() bool {
@@ -112,6 +114,14 @@ func GetPodNamespace() string {
112114
return GetStringConfigWithDefault(PodNamespace, DefaultPodNamespace)
113115
}
114116

117+
func GetMLPipelineServiceName() string {
118+
return GetStringConfigWithDefault(MLPipelineServiceName, DefaultMLPipelineServiceName)
119+
}
120+
121+
func GetMetadataServiceName() string {
122+
return GetStringConfigWithDefault(MetadataServiceName, DefaultMetadataServiceName)
123+
}
124+
115125
func GetBoolFromStringWithDefault(value string, defaultValue bool) bool {
116126
boolVal, err := strconv.ParseBool(value)
117127
if err != nil {

backend/src/apiserver/common/const.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,8 @@ const (
7575
const (
7676
DefaultPodNamespace string = "kubeflow"
7777
)
78+
79+
const (
80+
DefaultMLPipelineServiceName string = "ml-pipeline"
81+
DefaultMetadataServiceName string = "metadata-grpc-service"
82+
)

backend/src/v2/cacheutils/cache.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ const (
2626
MaxClientGRPCMessageSize = 100 * 1024 * 1024
2727
// The endpoint uses Kubernetes service DNS name with namespace:
2828
// https://kubernetes.io/docs/concepts/services-networking/service/#dns
29-
defaultKfpApiEndpoint = "ml-pipeline.kubeflow:8887"
3029
)
3130

3231
type Client interface {
@@ -77,7 +76,7 @@ type client struct {
7776
var _ Client = &client{}
7877

7978
// NewClient creates a Client.
80-
func NewClient(cacheDisabled bool, tlsCfg *tls.Config) (Client, error) {
79+
func NewClient(mlPipelineServerAddress string, mlPipelineServerPort string, cacheDisabled bool, tlsCfg *tls.Config) (Client, error) {
8180
if cacheDisabled {
8281
return &disabledCacheClient{}, nil
8382
}
@@ -86,9 +85,10 @@ func NewClient(cacheDisabled bool, tlsCfg *tls.Config) (Client, error) {
8685
if tlsCfg != nil {
8786
creds = credentials.NewTLS(tlsCfg)
8887
}
89-
glog.Infof("Connecting to cache endpoint %s", defaultKfpApiEndpoint)
88+
cacheEndPoint := mlPipelineServerAddress + ":" + mlPipelineServerPort
89+
glog.Infof("Connecting to cache endpoint %s", cacheEndPoint)
9090
conn, err := grpc.NewClient(
91-
defaultKfpApiEndpoint,
91+
cacheEndPoint,
9292
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxClientGRPCMessageSize)),
9393
grpc.WithTransportCredentials(creds),
9494
)

backend/src/v2/cacheutils/cache_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ func TestGenerateCacheKey(t *testing.T) {
218218
wantErr: false,
219219
},
220220
}
221-
cacheClient, err := NewClient(false, &tls.Config{})
221+
cacheClient, err := NewClient("ml-pipeline.kubeflow", "8887", false, &tls.Config{})
222222
require.NoError(t, err)
223223
for _, test := range tests {
224224
t.Run(test.name, func(t *testing.T) {
@@ -339,7 +339,7 @@ func TestGenerateFingerPrint(t *testing.T) {
339339
fingerPrint: "3d9a2a778fa3174c6cfc6e639c507c265b5f21ef6e5b1dd70b236462cc6da464",
340340
},
341341
}
342-
cacheClient, err := NewClient(false, &tls.Config{})
342+
cacheClient, err := NewClient("ml-pipeline.kubeflow", "8887", false, &tls.Config{})
343343
require.NoError(t, err)
344344
for _, test := range tests {
345345
t.Run(test.name, func(t *testing.T) {
@@ -409,7 +409,7 @@ func TestGenerateFingerPrint_ConsidersPVCNames(t *testing.T) {
409409
},
410410
}
411411

412-
cacheClient, err := NewClient(false, &tls.Config{})
412+
cacheClient, err := NewClient("ml-pipeline.kubeflow", "8887", false, &tls.Config{})
413413
require.NoError(t, err)
414414

415415
baseFP, err := cacheClient.GenerateFingerPrint(base)

backend/src/v2/client_manager/client_manager.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,13 @@ type ClientManager struct {
2929
}
3030

3131
type Options struct {
32-
MLMDServerAddress string
33-
MLMDServerPort string
34-
CacheDisabled bool
35-
CaCertPath string
36-
MLMDTLSEnabled bool
32+
MLPipelineServerAddress string
33+
MLPipelineServerPort string
34+
MLMDServerAddress string
35+
MLMDServerPort string
36+
CacheDisabled bool
37+
CaCertPath string
38+
MLMDTLSEnabled bool
3739
}
3840

3941
// NewClientManager creates and Init a new instance of ClientManager.
@@ -76,7 +78,7 @@ func (cm *ClientManager) init(opts *Options) error {
7678
if err != nil {
7779
return err
7880
}
79-
cacheClient, err := initCacheClient(opts.CacheDisabled, tlsCfg)
81+
cacheClient, err := initCacheClient(opts.MLPipelineServerAddress, opts.MLPipelineServerPort, opts.CacheDisabled, tlsCfg)
8082
if err != nil {
8183
return err
8284
}
@@ -102,6 +104,6 @@ func initMetadataClient(address string, port string, tlsCfg *tls.Config) (metada
102104
return metadata.NewClient(address, port, tlsCfg)
103105
}
104106

105-
func initCacheClient(cacheDisabled bool, tlsCfg *tls.Config) (cacheutils.Client, error) {
106-
return cacheutils.NewClient(cacheDisabled, tlsCfg)
107+
func initCacheClient(mlPipelineServerAddress string, mlPipelineServerPort string, cacheDisabled bool, tlsCfg *tls.Config) (cacheutils.Client, error) {
108+
return cacheutils.NewClient(mlPipelineServerAddress, mlPipelineServerPort, cacheDisabled, tlsCfg)
107109
}

backend/src/v2/cmd/driver/main.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,10 @@ var (
6969
k8sExecConfigJson = flag.String("kubernetes_config", "{}", "kubernetes executor config")
7070

7171
// config
72-
mlmdServerAddress = flag.String("mlmd_server_address", "", "MLMD server address")
73-
mlmdServerPort = flag.String("mlmd_server_port", "", "MLMD server port")
72+
mlPipelineServerAddress = flag.String("ml_pipeline_server_address", "ml-pipeline", "The name of the ML pipeline API server address.")
73+
mlPipelineServerPort = flag.String("ml_pipeline_server_port", "8887", "The port of the ML pipeline API server.")
74+
mlmdServerAddress = flag.String("mlmd_server_address", "", "MLMD server address")
75+
mlmdServerPort = flag.String("mlmd_server_port", "", "MLMD server port")
7476

7577
// output paths
7678
executionIDPath = flag.String("execution_id_path", "", "Exeucution ID output path")
@@ -190,7 +192,7 @@ func drive() (err error) {
190192
if err != nil {
191193
return err
192194
}
193-
cacheClient, err := cacheutils.NewClient(*cacheDisabledFlag, tlsCfg)
195+
cacheClient, err := cacheutils.NewClient(*mlPipelineServerAddress, *mlPipelineServerPort, *cacheDisabledFlag, tlsCfg)
194196
if err != nil {
195197
return err
196198
}

backend/src/v2/cmd/launcher-v2/main.go

Lines changed: 43 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,26 +28,28 @@ import (
2828

2929
// TODO: use https://github.com/spf13/cobra as a framework to create more complex CLI tools with subcommands.
3030
var (
31-
copy = flag.String("copy", "", "copy this binary to specified destination path")
32-
pipelineName = flag.String("pipeline_name", "", "pipeline context name")
33-
runID = flag.String("run_id", "", "pipeline run uid")
34-
parentDagID = flag.Int64("parent_dag_id", 0, "parent DAG execution ID")
35-
executorType = flag.String("executor_type", "container", "The type of the ExecutorSpec")
36-
executionID = flag.Int64("execution_id", 0, "Execution ID of this task.")
37-
executorInputJSON = flag.String("executor_input", "", "The JSON-encoded ExecutorInput.")
38-
componentSpecJSON = flag.String("component_spec", "", "The JSON-encoded ComponentSpec.")
39-
importerSpecJSON = flag.String("importer_spec", "", "The JSON-encoded ImporterSpec.")
40-
taskSpecJSON = flag.String("task_spec", "", "The JSON-encoded TaskSpec.")
41-
podName = flag.String("pod_name", "", "Kubernetes Pod name.")
42-
podUID = flag.String("pod_uid", "", "Kubernetes Pod UID.")
43-
mlmdServerAddress = flag.String("mlmd_server_address", "", "The MLMD gRPC server address.")
44-
mlmdServerPort = flag.String("mlmd_server_port", "8080", "The MLMD gRPC server port.")
45-
logLevel = flag.String("log_level", "1", "The verbosity level to log.")
46-
publishLogs = flag.String("publish_logs", "true", "Whether to publish component logs to the object store")
47-
cacheDisabledFlag = flag.Bool("cache_disabled", false, "Disable cache globally.")
48-
caCertPath = flag.String("ca_cert_path", "", "The path to the CA certificate to trust on connections to the ML pipeline API server and metadata server.")
49-
mlPipelineTLSEnabled = flag.Bool("ml_pipeline_tls_enabled", false, "Set to true if mlpipeline API server serves over TLS.")
50-
metadataTLSEnabled = flag.Bool("metadata_tls_enabled", false, "Set to true if MLMD serves over TLS.")
31+
copy = flag.String("copy", "", "copy this binary to specified destination path")
32+
pipelineName = flag.String("pipeline_name", "", "pipeline context name")
33+
runID = flag.String("run_id", "", "pipeline run uid")
34+
parentDagID = flag.Int64("parent_dag_id", 0, "parent DAG execution ID")
35+
executorType = flag.String("executor_type", "container", "The type of the ExecutorSpec")
36+
executionID = flag.Int64("execution_id", 0, "Execution ID of this task.")
37+
executorInputJSON = flag.String("executor_input", "", "The JSON-encoded ExecutorInput.")
38+
componentSpecJSON = flag.String("component_spec", "", "The JSON-encoded ComponentSpec.")
39+
importerSpecJSON = flag.String("importer_spec", "", "The JSON-encoded ImporterSpec.")
40+
taskSpecJSON = flag.String("task_spec", "", "The JSON-encoded TaskSpec.")
41+
podName = flag.String("pod_name", "", "Kubernetes Pod name.")
42+
podUID = flag.String("pod_uid", "", "Kubernetes Pod UID.")
43+
mlPipelineServerAddress = flag.String("ml_pipeline_server_address", "ml-pipeline.kubeflow", "The name of the ML pipeline API server address.")
44+
mlPipelineServerPort = flag.String("ml_pipeline_server_port", "8887", "The port of the ML pipeline API server.")
45+
mlmdServerAddress = flag.String("mlmd_server_address", "", "The MLMD gRPC server address.")
46+
mlmdServerPort = flag.String("mlmd_server_port", "8080", "The MLMD gRPC server port.")
47+
logLevel = flag.String("log_level", "1", "The verbosity level to log.")
48+
publishLogs = flag.String("publish_logs", "true", "Whether to publish component logs to the object store")
49+
cacheDisabledFlag = flag.Bool("cache_disabled", false, "Disable cache globally.")
50+
caCertPath = flag.String("ca_cert_path", "", "The path to the CA certificate to trust on connections to the ML pipeline API server and metadata server.")
51+
mlPipelineTLSEnabled = flag.Bool("ml_pipeline_tls_enabled", false, "Set to true if mlpipeline API server serves over TLS.")
52+
metadataTLSEnabled = flag.Bool("metadata_tls_enabled", false, "Set to true if MLMD serves over TLS.")
5153
)
5254

5355
func main() {
@@ -79,18 +81,20 @@ func run() error {
7981
}
8082

8183
launcherV2Opts := &component.LauncherV2Options{
82-
Namespace: namespace,
83-
PodName: *podName,
84-
PodUID: *podUID,
85-
MLMDServerAddress: *mlmdServerAddress,
86-
MLMDServerPort: *mlmdServerPort,
87-
PipelineName: *pipelineName,
88-
RunID: *runID,
89-
PublishLogs: *publishLogs,
90-
CacheDisabled: *cacheDisabledFlag,
91-
MLPipelineTLSEnabled: *mlPipelineTLSEnabled,
92-
MLMDTLSEnabled: *metadataTLSEnabled,
93-
CaCertPath: *caCertPath,
84+
Namespace: namespace,
85+
PodName: *podName,
86+
PodUID: *podUID,
87+
MLPipelineServerAddress: *mlPipelineServerAddress,
88+
MLPipelineServerPort: *mlPipelineServerPort,
89+
MLMDServerAddress: *mlmdServerAddress,
90+
MLMDServerPort: *mlmdServerPort,
91+
PipelineName: *pipelineName,
92+
RunID: *runID,
93+
PublishLogs: *publishLogs,
94+
CacheDisabled: *cacheDisabledFlag,
95+
MLPipelineTLSEnabled: *mlPipelineTLSEnabled,
96+
MLMDTLSEnabled: *metadataTLSEnabled,
97+
CaCertPath: *caCertPath,
9498
}
9599

96100
switch *executorType {
@@ -110,11 +114,13 @@ func run() error {
110114
return nil
111115
case "container":
112116
clientOptions := &client_manager.Options{
113-
MLMDServerAddress: launcherV2Opts.MLMDServerAddress,
114-
MLMDServerPort: launcherV2Opts.MLMDServerPort,
115-
CacheDisabled: launcherV2Opts.CacheDisabled,
116-
MLMDTLSEnabled: launcherV2Opts.MLMDTLSEnabled,
117-
CaCertPath: launcherV2Opts.CaCertPath,
117+
MLPipelineServerAddress: launcherV2Opts.MLPipelineServerAddress,
118+
MLPipelineServerPort: launcherV2Opts.MLPipelineServerPort,
119+
MLMDServerAddress: launcherV2Opts.MLMDServerAddress,
120+
MLMDServerPort: launcherV2Opts.MLMDServerPort,
121+
CacheDisabled: launcherV2Opts.CacheDisabled,
122+
MLMDTLSEnabled: launcherV2Opts.MLMDTLSEnabled,
123+
CaCertPath: launcherV2Opts.CaCertPath,
118124
}
119125
clientManager, err := client_manager.NewClientManager(clientOptions)
120126
if err != nil {

backend/src/v2/compiler/argocompiler/container.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"strconv"
2222
"strings"
2323

24+
"github.com/kubeflow/pipelines/backend/src/v2/config"
2425
"github.com/kubeflow/pipelines/backend/src/v2/metadata"
2526
"google.golang.org/protobuf/encoding/protojson"
2627

@@ -207,8 +208,10 @@ func (c *workflowCompiler) addContainerDriverTemplate() string {
207208
"--http_proxy", proxy.GetConfig().GetHttpProxy(),
208209
"--https_proxy", proxy.GetConfig().GetHttpsProxy(),
209210
"--no_proxy", proxy.GetConfig().GetNoProxy(),
210-
"--mlmd_server_address", metadata.DefaultConfig().Address,
211-
"--mlmd_server_port", metadata.DefaultConfig().Port,
211+
"--ml_pipeline_server_address", config.GetMLPipelineServerConfig().Address,
212+
"--ml_pipeline_server_port", config.GetMLPipelineServerConfig().Port,
213+
"--mlmd_server_address", metadata.GetMetadataConfig().Address,
214+
"--mlmd_server_port", metadata.GetMetadataConfig().Port,
212215
}
213216
if c.cacheDisabled {
214217
args = append(args, "--cache_disabled")

backend/src/v2/compiler/argocompiler/dag.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"strings"
2121

2222
"github.com/kubeflow/pipelines/backend/src/apiserver/config/proxy"
23+
"github.com/kubeflow/pipelines/backend/src/v2/config"
2324
"github.com/kubeflow/pipelines/backend/src/v2/metadata"
2425

2526
wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
@@ -572,8 +573,10 @@ func (c *workflowCompiler) addDAGDriverTemplate() string {
572573
"--http_proxy", proxy.GetConfig().GetHttpProxy(),
573574
"--https_proxy", proxy.GetConfig().GetHttpsProxy(),
574575
"--no_proxy", proxy.GetConfig().GetNoProxy(),
575-
"--mlmd_server_address", metadata.DefaultConfig().Address,
576-
"--mlmd_server_port", metadata.DefaultConfig().Port,
576+
"--ml_pipeline_server_address", config.GetMLPipelineServerConfig().Address,
577+
"--ml_pipeline_server_port", config.GetMLPipelineServerConfig().Port,
578+
"--mlmd_server_address", metadata.GetMetadataConfig().Address,
579+
"--mlmd_server_port", metadata.GetMetadataConfig().Port,
577580
}
578581
if c.cacheDisabled {
579582
args = append(args, "--cache_disabled")

backend/src/v2/compiler/argocompiler/importer.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ func (c *workflowCompiler) addImporterTemplate() string {
7979
fmt.Sprintf("$(%s)", component.EnvPodName),
8080
"--pod_uid",
8181
fmt.Sprintf("$(%s)", component.EnvPodUID),
82-
"--mlmd_server_address", metadata.DefaultConfig().Address,
83-
"--mlmd_server_port", metadata.DefaultConfig().Port,
82+
"--mlmd_server_address", metadata.GetMetadataConfig().Address,
83+
"--mlmd_server_port", metadata.GetMetadataConfig().Port,
8484
}
8585
if c.cacheDisabled {
8686
args = append(args, "--cache_disabled")

0 commit comments

Comments
 (0)