Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 19 additions & 84 deletions src/generic/UnivPoly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -783,93 +783,26 @@ end
#
###############################################################################

function evaluate(a::UnivPoly{T}, A::Vector{T}) where {T <: RingElem}
R = base_ring(a)
function evaluate(a::UnivPoly, A::Vector{<:Union{NCRingElem, RingElement}})
a2 = data(a)
varidx = var_indices(a2)
isempty(varidx) && return constant_coefficient(a2)
isempty(A) && error("Number of variables does not match number of values")
vals = zeros(parent(A[1]), nvars(parent(a2)))
n = length(A)
num = nvars(parent(data(a)))
if n > num
n > nvars(parent(a)) && error("Too many values")
if nvars(parent(data(a))) == 0
return constant_coefficient(data(a))*one(parent(A[1]))
end
return evaluate(data(a), A[1:num])
end
if n < num
A = vcat(A, [zero(R) for i = 1:num - n])
end
return evaluate(data(a), A)
end

function evaluate(a::UnivPoly{T}, A::Vector{V}) where {T <: RingElement, V <: Union{Integer, Rational, AbstractFloat}}
n = length(A)
num = nvars(parent(data(a)))
if n > num
n > nvars(parent(a)) && error("Too many values")
if nvars(parent(data(a))) == 0
return constant_coefficient(data(a))*one(parent(A[1]))
end
return evaluate(data(a), A[1:num])
end
if n < num
A = vcat(A, zeros(V, num - n))
for i in varidx
i <= n || error("Number of variables does not match number of values")
vals[i] = A[i]
end
return evaluate(data(a), A)
return evaluate(a2, vals)
end

function evaluate(a::UnivPoly{T}, A::Vector{V}) where {T <: RingElement, V <: RingElement}
n = length(A)
num = nvars(parent(data(a)))
if n > num
n > nvars(parent(a)) && error("Too many values")
if nvars(parent(data(a))) == 0
return constant_coefficient(data(a))*one(parent(A[1]))
end
return evaluate(data(a), A[1:num])
end
if n < num
if n == 0
R = base_ring(a)
return evaluate(data(a), [zero(R) for _ in 1:num])
else
R = parent(A[1])
A = vcat(A, [zero(R) for _ in 1:num-n])
return evaluate(data(a), A)
end
end
return evaluate(data(a), A)
end

function (a::UnivPoly{T})() where {T <: RingElement}
return evaluate(a, T[])
end

function (a::UnivPoly{T})(vals::T...) where {T <: RingElement}
return evaluate(a, [vals...])
end

function (a::UnivPoly{T})(val::V, vals::V...) where {T <: RingElement, V <: Union{Integer, Rational, AbstractFloat}}
function (a::UnivPoly)(val::T, vals::T...) where T <: Union{NCRingElem, RingElement}
return evaluate(a, [val, vals...])
end

function (a::UnivPoly{T})(vals::Union{NCRingElem, RingElement}...) where {T <: RingElement}
A = [vals...]
n = length(vals)
num = nvars(parent(data(a)))
if n > num
n > nvars(parent(a)) && error("Too many values")
if nvars(parent(data(a))) == 0
return constant_coefficient(data(a))*one(parent(A[1]))
end
return data(a)(vals[1:num]...)
end
if n < num
A = vcat(A, zeros(Int, num - n))
end
return data(a)(A...)
end

function evaluate(a::UnivPoly{T}, vals::Vector{V}) where {T <: RingElement, V <: NCRingElem}
return a(vals...)
function (a::UnivPoly)()
return evaluate(a, Int[])
end

function evaluate(a::UnivPoly{T}, vars::Vector{Int}, vals::Vector{V}) where {T <: RingElement, V <: RingElement}
Expand All @@ -878,15 +811,14 @@ function evaluate(a::UnivPoly{T}, vars::Vector{Int}, vals::Vector{V}) where {T <
vals2 = Vector{mpoly_type(T)}(undef, 0)
num = nvars(parent(data(a)))
S = parent(a)
n = nvars(S)
a2 = data(S(a))
for i = 1:length(vars)
vars[i] > n && error("Unknown variable")
if vars[i] <= num
push!(vars2, vars[i])
push!(vals2, data(S(vals[i])))
end
end
return UnivPoly(evaluate(data(S(a)), vars2, vals2), S)
return UnivPoly(evaluate(a2, vars2, vals2), S)
end

function evaluate(a::S, vars::Vector{S}, vals::Vector{V}) where {S <: UnivPoly{T}, V <: RingElement} where {T <: RingElement}
Expand All @@ -900,7 +832,10 @@ function (a::Union{MPolyRingElem, UniversalPolyRingElem})(;kwargs...)
vals = Array{RingElement}(undef, length(kwargs))
for (i, (var, val)) in enumerate(kwargs)
vari = findfirst(isequal(var), ss)
vari === nothing && error("Given polynomial has no variable $var")
if vari === nothing
isa(a, MPolyRingElem) && error("Given polynomial has no variable $var")
continue
end
vars[i] = vari
vals[i] = val
end
Expand Down
7 changes: 0 additions & 7 deletions test/generic/UnivPoly-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -982,13 +982,6 @@ end
@test evaluate(g, V) == g([ZZ(v) for v in V]...)
@test evaluate(g, V) == g([U(v) for v in V]...)

@test evaluate(h, V) == evaluate(h, [R(v) for v in V])
@test evaluate(h, V) == evaluate(h, [ZZ(v) for v in V])
@test evaluate(h, V) == evaluate(h, [U(v) for v in V])
@test evaluate(h, V) == h(V...)
@test evaluate(h, V) == h([ZZ(v) for v in V]...)
@test evaluate(h, V) == h([U(v) for v in V]...)

V = [rand(-10:10) for v in 1:2]

@test evaluate(f, [1], [V[1]]) == evaluate(f, [1], [R(V[1])])
Expand Down
Loading