Skip to content

Commit 7e6112d

Browse files
mask changes
1 parent f50fe65 commit 7e6112d

File tree

1 file changed

+43
-43
lines changed

1 file changed

+43
-43
lines changed

src/SVRfunctions.jl

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ function fit(y::AbstractArray{T}, x::AbstractArray{T}; kw...) where {T <: Number
114114
return yp
115115
end
116116

117-
function fit_test(y::AbstractVector{Float64}, x::AbstractArray{Float64}; ratio::Number=0.1, repeats::Number=1, pm=nothing, keepcases::Union{BitArray,Nothing}=nothing, scale::Bool=false, ymin::Number=minimum(y), ymax::Number=maximum(y), quiet::Bool=false, veryquiet::Bool=true, total::Bool=false, rmse::Bool=true, callback::Function=(y::AbstractVector, y_pr::AbstractVector, pm::AbstractVector)->nothing, kw...)
117+
function fit_test(y::AbstractVector{Float64}, x::AbstractArray{Float64}; ratio_prediction::Number=0.1, repeats::Number=1, mask_prediction=nothing, keepcases::Union{BitArray,Nothing}=nothing, scale::Bool=false, ymin::Number=minimum(y), ymax::Number=maximum(y), quiet::Bool=false, veryquiet::Bool=true, total::Bool=false, rmse::Bool=true, callback::Function=(y::AbstractVector, y_pr::AbstractVector, mask_prediction::AbstractVector)->nothing, kw...)
118118
if !isnothing(keepcases)
119119
@assert length(keepcases) == size(x, 2)
120120
end
@@ -127,64 +127,64 @@ function fit_test(y::AbstractVector{Float64}, x::AbstractArray{Float64}; ratio::
127127
pma = Vector{Bool}(undef, 0)
128128
local y_pr
129129
for r in 1:repeats
130-
if repeats > 1 || isnothing(pm)
131-
pm = get_prediction_mask(length(y), ratio; keepcases=keepcases)
130+
if repeats > 1 || isnothing(mask_prediction)
131+
mask_prediction = get_prediction_mask(length(y), ratio_prediction; keepcases=keepcases)
132132
else
133-
@assert length(pm) == size(x, 2)
134-
@assert eltype(pm) <: Bool
133+
@assert length(mask_prediction) == size(x, 2)
134+
@assert eltype(mask_prediction) <: Bool
135135
end
136-
ic = sum(.!pm)
136+
ic = sum(.!mask_prediction)
137137
if !quiet && repeats == 1 && length(y) > ic
138-
@info("Training on $(ic) out of $(length(y)) (prediction ratio $ratio) ...")
138+
@info("Training on $(ic) out of $(length(y)) (prediction ratio_prediction $ratio_prediction) ...")
139139
end
140-
pmodel = train(a[.!pm], x[:,.!pm]; kw...)
140+
pmodel = train(a[.!mask_prediction], x[:,.!mask_prediction]; kw...)
141141
y_pr = predict(pmodel, x)
142142
freemodel(pmodel)
143143
if any(isnan.(y_pr))
144144
@warn("SVR output contains NaN's!")
145145
end
146146
if rmse
147-
m[r] = total ? rmse(y_pr, a) : rmse(y_pr[pm], a[pm])
147+
m[r] = total ? rmse(y_pr, a) : rmse(y_pr[mask_prediction], a[mask_prediction])
148148
else
149-
m[r] = total ? r2(y_pr, a) : r2(y_pr[pm], a[pm])
149+
m[r] = total ? r2(y_pr, a) : r2(y_pr[mask_prediction], a[mask_prediction])
150150
end
151151
if !veryquiet && repeats > 1
152152
println("Repeat $r: $(m[r])")
153153
end
154154
y_pra = vcat(y_pra, y_pr)
155155
ya = vcat(ya, y)
156-
pma = vcat(pma, pm)
156+
pma = vcat(pma, mask_prediction)
157157
end
158158
y_pra = y_pra * (ymax - ymin) .+ ymin
159159
callback(ya, y_pra, pma)
160160
y_pr = y_pr * (ymax - ymin) .+ ymin
161-
return y_pr, pm, Statistics.mean(m)
161+
return y_pr, mask_prediction, Statistics.mean(m)
162162
end
163-
function fit_test(y::AbstractVector{T}, x::AbstractArray{T}; ratio::Number=0.1, kw...) where {T <: Number}
164-
y_pr, pm, rmse = fit_test(Float64.(y), Float64.(x); ratio=ratio, kw...)
165-
return T.(y_pr), pm, rmse
163+
function fit_test(y::AbstractVector{T}, x::AbstractArray{T}; ratio_prediction::Number=0.1, kw...) where {T <: Number}
164+
y_pr, mask_prediction, rmse = fit_test(Float64.(y), Float64.(x); ratio_prediction=ratio_prediction, kw...)
165+
return T.(y_pr), mask_prediction, rmse
166166
end
167-
function fit_test(y::AbstractArray{T}, x::AbstractArray{T}; ratio::Number=0.1, pm=nothing, keepcases::Union{BitArray,Nothing}=nothing, kw...) where {T <: Number}
167+
function fit_test(y::AbstractArray{T}, x::AbstractArray{T}; ratio_prediction::Number=0.1, mask_prediction=nothing, keepcases::Union{BitArray,Nothing}=nothing, kw...) where {T <: Number}
168168
@assert size(y, 1) == size(x, 2)
169169
if !isnothing(keepcases)
170170
@assert length(keepcases) == size(x, 2)
171171
end
172-
if isnothing(pm)
173-
pm = get_prediction_mask(size(y, 1), ratio; keepcases=keepcases)
172+
if isnothing(mask_prediction)
173+
mask_prediction = get_prediction_mask(size(y, 1), ratio_prediction; keepcases=keepcases)
174174
end
175175
yp = similar(y)
176176
for i = 1:size(y, 2)
177-
yp[:,i], _, rmse = fit_test(vec(y[:,i]), x; ratio=ratio, pm=pm, kw...)
177+
yp[:,i], _, rmse = fit_test(vec(y[:,i]), x; ratio_prediction=ratio_prediction, mask_prediction=mask_prediction, kw...)
178178
end
179-
return yp, pm, rmse
179+
return yp, mask_prediction, rmse
180180
end
181-
function fit_test(y::AbstractVector{T}, x::AbstractArray{T}, vattr::Union{AbstractVector,AbstractRange}; ratio::Number=0.1, attr=:gamma, rmse::Bool=true, check::Function=(v::AbstractVector)->nothing, kw...) where {T <: Number}
181+
function fit_test(y::AbstractVector{T}, x::AbstractArray{T}, vattr::Union{AbstractVector,AbstractRange}; ratio_prediction::Number=0.1, attr=:gamma, rmse::Bool=true, check::Function=(v::AbstractVector)->nothing, kw...) where {T <: Number}
182182
@assert length(vattr) > 0
183-
@info("Grid search on $attr with prediction ratio $ratio ...")
183+
@info("Grid search on $attr with prediction ratio_prediction $ratio_prediction ...")
184184
ma = Vector{T}(undef, length(vattr))
185185
for (i, g) in enumerate(vattr)
186186
k = Dict(attr=>g)
187-
y_pr, pm, ma[i] = fit_test(y, x; ratio=ratio, rmse=rmse, kw..., k..., quiet=true)
187+
y_pr, mask_prediction, ma[i] = fit_test(y, x; ratio_prediction=ratio_prediction, rmse=rmse, kw..., k..., quiet=true)
188188
@info("$attr=>$g: $(ma[i])")
189189
end
190190
c = check(ma)
@@ -195,64 +195,64 @@ function fit_test(y::AbstractVector{T}, x::AbstractArray{T}, vattr::Union{Abstra
195195
m = ma[i]
196196
end
197197
k = Dict(attr=>vattr[i])
198-
return m, vattr[i], fit_test(y, x; ratio=ratio, rmse=rmse, kw..., k..., repeats=1)...
198+
return m, vattr[i], fit_test(y, x; ratio_prediction=ratio_prediction, rmse=rmse, kw..., k..., repeats=1)...
199199
end
200-
function fit_test(y::AbstractVector{T}, x::AbstractArray{T}, vattr1::Union{AbstractVector,AbstractRange}, vattr2::Union{AbstractVector,AbstractRange}; ratio::Number=0.1, attr1=:gamma, attr2=:epsilon, rmse::Bool=true, kw...) where {T <: Number}
200+
function fit_test(y::AbstractVector{T}, x::AbstractArray{T}, vattr1::Union{AbstractVector,AbstractRange}, vattr2::Union{AbstractVector,AbstractRange}; ratio_prediction::Number=0.1, attr1=:gamma, attr2=:epsilon, rmse::Bool=true, kw...) where {T <: Number}
201201
@assert length(vattr1) > 0
202202
@assert length(vattr2) > 0
203-
@info("Grid search on $attr1/$attr2 with prediction ratio $ratio ...")
203+
@info("Grid search on $attr1/$attr2 with prediction ratio_prediction $ratio_prediction ...")
204204
ma = Matrix{T}(undef, length(vattr1), length(vattr2))
205205
for (i, a1) in enumerate(vattr1)
206206
for (j, a2) in enumerate(vattr2)
207207
k = Dict(attr1=>a1, attr2=>a2)
208-
y_pr, pm, ma[i, j] = fit_test(y, x; ratio=ratio, rmse=rmse, kw..., k...)
208+
y_pr, mask_prediction, ma[i, j] = fit_test(y, x; ratio_prediction=ratio_prediction, rmse=rmse, kw..., k...)
209209
@info("$attr1=>$a1 $attr2=>$a2: $(ma[i,j])")
210210
end
211211
end
212212
m, i = rmse ? findmin(ma) : findmax(ma)
213213
k = Dict(attr1=>vattr1[i.I[1]], attr2=>vattr2[i.I[2]])
214-
return m, vattr1[i.I[1]], vattr2[i.I[2]], fit_test(y, x; ratio=ratio, rmse=rmse, kw..., k..., repeats=1)...
214+
return m, vattr1[i.I[1]], vattr2[i.I[2]], fit_test(y, x; ratio_prediction=ratio_prediction, rmse=rmse, kw..., k..., repeats=1)...
215215
end
216-
function fit_test(y::AbstractVector{T}, x::AbstractArray{T}, vattr1::Union{AbstractVector,AbstractRange}, vattr2::Union{AbstractVector,AbstractRange}, vattr3::Union{AbstractVector,AbstractRange}; ratio::Number=0.1, attr1=:gamma, attr2=:epsilon, attr3=:C, rmse::Bool=true, kw...) where {T <: Number}
216+
function fit_test(y::AbstractVector{T}, x::AbstractArray{T}, vattr1::Union{AbstractVector,AbstractRange}, vattr2::Union{AbstractVector,AbstractRange}, vattr3::Union{AbstractVector,AbstractRange}; ratio_prediction::Number=0.1, attr1=:gamma, attr2=:epsilon, attr3=:C, rmse::Bool=true, kw...) where {T <: Number}
217217
@assert length(vattr1) > 0
218218
@assert length(vattr2) > 0
219219
@assert length(vattr3) > 0
220-
@info("Grid search on $attr1/$attr2/$attr3 with prediction ratio $ratio ...")
220+
@info("Grid search on $attr1/$attr2/$attr3 with prediction ratio_prediction $ratio_prediction ...")
221221
ma = Array{T}(undef, length(vattr1), length(vattr2), length(vattr3))
222222
for (i, a1) in enumerate(vattr1)
223223
for (j, a2) in enumerate(vattr2)
224224
for (k, a3) in enumerate(vattr3)
225225
kk = Dict(attr1=>a1, attr2=>a2, attr3=>a3)
226-
y_pr, pm, ma[i, j, k] = fit_test(y, x; ratio=ratio, rmse=rmse, kw..., kk...)
226+
y_pr, mask_prediction, ma[i, j, k] = fit_test(y, x; ratio_prediction=ratio_prediction, rmse=rmse, kw..., kk...)
227227
@info("$attr1=>$a1 $attr2=>$a2 $attr3=>$a3: $(ma[i,j,k])")
228228
end
229229
end
230230
end
231231
m, i = rmse ? findmin(ma) : findmax(ma)
232232
k = Dict(attr1=>vattr1[i.I[1]], attr2=>vattr2[i.I[2]], attr3=>vattr3[i.I[3]])
233-
return m, vattr1[i.I[1]], vattr2[i.I[2]], vattr3[i.I[3]], fit_test(y, x; ratio=ratio, rmse=rmse, kw..., k..., repeats=1)...
233+
return m, vattr1[i.I[1]], vattr2[i.I[2]], vattr3[i.I[3]], fit_test(y, x; ratio_prediction=ratio_prediction, rmse=rmse, kw..., k..., repeats=1)...
234234
end
235235

236236
"""
237237
Get prediction mask
238238
239239
$(DocumentFunction.documentfunction(get_prediction_mask;
240240
argtext=Dict("ns"=>"number of samples",
241-
"ratio"=>"prediction ratio")))
241+
"ratio_prediction"=>"prediction ratio_prediction")))
242242
243243
Return:
244244
245245
- prediction mask
246246
"""
247-
function get_prediction_mask(ns::Number, ratio::Number; keepcases::Union{AbstractVector,Nothing}=nothing, debug::Bool=false)
247+
function get_prediction_mask(ns::Number, ratio_prediction::Number; keepcases::Union{AbstractVector,Nothing}=nothing, debug::Bool=false)
248248
nsi = copy(ns)
249-
pm = trues(ns)
250-
ic = convert(Int64, ceil(ns * (1. - ratio)))
249+
mask_prediction = trues(ns)
250+
ic = convert(Int64, ceil(ns * (1. - ratio_prediction)))
251251
if !isnothing(keepcases)
252-
@assert length(keepcases) == length(pm)
252+
@assert length(keepcases) == length(mask_prediction)
253253
kn = sum(keepcases)
254254
if ic > kn && ns > kn
255-
pm[keepcases] .= false
255+
mask_prediction[keepcases] .= false
256256
ic -= kn
257257
nsi -= kn
258258
else
@@ -264,16 +264,16 @@ function get_prediction_mask(ns::Number, ratio::Number; keepcases::Union{Abstrac
264264
if !isnothing(keepcases) && ic > kn
265265
m = trues(nsi)
266266
m[ir] .= false
267-
pm[.!keepcases] .= m
267+
mask_prediction[.!keepcases] .= m
268268
else
269-
pm[ir] .= false
269+
mask_prediction[ir] .= false
270270
end
271271
end
272272
if debug
273-
@info("Number of cases for training: $(ns - sum(pm))")
274-
@info("Number of cases for prediction: $(sum(pm))")
273+
@info("Number of cases for training: $(ns - sum(mask_prediction))")
274+
@info("Number of cases for prediction: $(sum(mask_prediction))")
275275
end
276-
return pm
276+
return mask_prediction
277277
end
278278

279279
"""

0 commit comments

Comments
 (0)