Skip to content

Commit bf4eef9

Browse files
zhoguyukirora
andauthored
ModelProxy supports the "jobType" field of inference jobs (#118)
* support job Type and inference job parameters * update * update * update * update * update --------- Co-authored-by: Yuting Jiang <[email protected]>
1 parent a3ba466 commit bf4eef9

File tree

9 files changed

+170
-123
lines changed

9 files changed

+170
-123
lines changed

src/model-proxy/config/model-proxy.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ service_type: "common"
55

66
port: 9999
77
retry: 5
8-
modelkey: "123"
8+
concurrency: 10

src/model-proxy/deploy/model-proxy.yaml.template

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,12 @@ spec:
2727
args:
2828
- "--port={{ cluster_cfg['model-proxy']['port'] }}"
2929
- "--retry={{ cluster_cfg['model-proxy']['retry'] }}"
30-
- "--modelkey={{ cluster_cfg['model-proxy']['modelkey'] }}"
3130
- "--logdir=/usr/local/ltp/model-proxy/logs"
3231
env:
3332
- name: REST_SERVER_URI
3433
value: {{ cluster_cfg["rest-server"]["uri"] }}
34+
- name: CONCURRENCY
35+
value: "{{ cluster_cfg['model-proxy']['concurrency'] }}"
3536
volumeMounts:
3637
{%- if cluster_cfg['model-proxy']['log_pvc'] %}
3738
- name: model-proxy-log-storage

src/model-proxy/src/main.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,12 @@ var (
1515
port int
1616
maxRetries int = 5 // default value
1717
logFileDir string
18-
modelKey string
1918
)
2019

2120
func init() {
2221
flag.IntVar(&port, "port", 9999, "port for the proxy server")
2322
flag.IntVar(&maxRetries, "retry", 5, "max retries for the request to the model server")
2423
flag.StringVar(&logFileDir, "logdir", "./logs", "path to the log file directory")
25-
flag.StringVar(&modelKey, "modelkey", "", "model key for requesting model serving jobs")
2624
}
2725

2826
func main() {
@@ -33,7 +31,6 @@ func main() {
3331
Host: "0.0.0.0",
3432
Port: port,
3533
MaxRetries: maxRetries,
36-
ModelKey: modelKey,
3734
},
3835
Log: &types.Log{
3936
LogStorage: &types.LogStorage{

src/model-proxy/src/proxy/authenticator.go

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import (
77
"log"
88
"net/http"
99
"strings"
10+
11+
"modelproxy/types"
1012
)
1113

1214
// obfuscateToken returns a truncated identifier for safely logging tokens.
@@ -18,61 +20,57 @@ func obfuscateToken(token string) string {
1820
}
1921

2022
type RestServerAuthenticator struct {
21-
// rest-server token => model names => model urls
22-
tokenToModels map[string]map[string][]string
23-
modelKey string
23+
// rest-server token => model names => model service list
24+
tokenToModels map[string]map[string][]*types.BaseSpec
2425
}
2526

26-
func NewRestServerAuthenticator(tokenToModels map[string]map[string][]string, modelKey string) *RestServerAuthenticator {
27-
if tokenToModels == nil {
28-
tokenToModels = make(map[string]map[string][]string)
29-
}
27+
func NewRestServerAuthenticator() *RestServerAuthenticator {
3028
return &RestServerAuthenticator{
31-
tokenToModels: tokenToModels,
32-
modelKey: modelKey,
29+
tokenToModels: make(map[string]map[string][]*types.BaseSpec),
3330
}
3431
}
3532

36-
func (ra *RestServerAuthenticator) UpdateTokenModels(token string, model2Url map[string][]string) {
33+
// UpdateTokenModels updates the model mapping for a given token
34+
func (ra *RestServerAuthenticator) UpdateTokenModels(token string, model2Service map[string][]*types.BaseSpec) {
3735
if ra.tokenToModels == nil {
38-
ra.tokenToModels = make(map[string]map[string][]string)
36+
ra.tokenToModels = make(map[string]map[string][]*types.BaseSpec)
3937
}
40-
ra.tokenToModels[token] = model2Url
38+
ra.tokenToModels[token] = model2Service
4139
}
4240

4341
// Check if the request is authenticated and return the available model urls
44-
func (ra *RestServerAuthenticator) AuthenticateReq(req *http.Request, reqBody map[string]interface{}) (bool, []string) {
42+
func (ra *RestServerAuthenticator) AuthenticateReq(req *http.Request, reqBody map[string]interface{}) (bool, []*types.BaseSpec) {
4543
token := req.Header.Get("Authorization")
4644
token = strings.Replace(token, "Bearer ", "", 1)
4745
// read request body
4846
model, ok := reqBody["model"].(string)
4947
if !ok {
5048
log.Printf("[-] Error: 'model' field missing or not a string in request body")
51-
return false, []string{}
49+
return false, nil
5250
}
5351
availableModels, ok := ra.tokenToModels[token]
5452
if !ok {
5553
// request to RestServer to get the models
5654
log.Printf("[-] Error: token %s not found in the authenticator\n", obfuscateToken(token))
57-
availableModels, err := GetJobModelsMapping(req, ra.modelKey)
55+
availableModels, err := GetJobModelsMapping(req)
5856
if err != nil {
5957
log.Printf("[-] Error: failed to get models for token %s: %v\n", obfuscateToken(token), err)
60-
return false, []string{}
58+
return false, nil
6159
}
6260
ra.tokenToModels[token] = availableModels
6361
}
6462
if len(availableModels) == 0 {
6563
log.Printf("[-] Error: no models found")
66-
return false, []string{}
64+
return false, nil
6765
}
6866
if model == "" {
6967
log.Printf("[-] Error: model is empty")
70-
return false, []string{}
68+
return false, nil
7169
}
7270
for m, v := range availableModels {
7371
if m == model {
7472
return true, v
7573
}
7674
}
77-
return false, []string{}
75+
return false, nil
7876
}

src/model-proxy/src/proxy/load_balancer.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,16 @@ func NewUrlPoller(url string, bsl types.BaseSpecList) *UrlPoller {
2828
}
2929
}
3030

31-
func NewUrlPollerWithKey(url string, modelUrls []string, modelKey string) *UrlPoller {
32-
if len(modelUrls) == 0 {
31+
func NewUrlPollerWithKey(url string, modelServices []*types.BaseSpec) *UrlPoller {
32+
if len(modelServices) == 0 {
3333
return nil
3434
}
35-
bsl := make(types.BaseSpecList, 0, len(modelUrls))
36-
for _, v := range modelUrls {
35+
bsl := make(types.BaseSpecList, 0, len(modelServices))
36+
for _, v := range modelServices {
3737
bsl = append(bsl, &types.BaseSpecStatistic{
3838
BaseSpec: &types.BaseSpec{
39-
URL: v,
40-
Key: modelKey,
39+
URL: v.URL,
40+
Key: v.Key,
4141
},
4242
})
4343
}

0 commit comments

Comments
 (0)