GPU Support
In this example, contributed by Zack Li, we solve the Lasso problem
$\min_x \tfrac{1}{2}\|Ax - b\|_2^2 + \lambda \|x\|_1$ with GeNIOS's generic interface (`GenericSolver`)
and a custom `HessianOperator`,
running on an NVIDIA GPU via CUDA.jl.
using CUDA
using GeNIOS
using Random, LinearAlgebra, SparseArrays
# Problem data, identical to the generic interface tutorial: a column-centered,
# column-normalized Gaussian design matrix, a sparse ground truth, and noisy
# observations of it.
Random.seed!(1)
m, n = 5000, 10000
A = randn(m, n)
A .= A .- sum(A; dims=1) ./ m      # center each column (subtract its mean)
foreach(normalize!, eachcol(A))    # scale each column to unit 2-norm
xstar = sprandn(n, 0.1)            # sparse ground truth with density 0.1
b = A * xstar .+ 1e-3 .* randn(m)  # measurements with small Gaussian noise
λ = 0.05 * norm(A' * b, Inf)       # regularization weight: 5% of ‖Aᵀb‖∞
# Float64 CPU buffers; they are converted to the chosen array type later on.
solver_c = zeros(n)
tmp_arr = zeros(m)
"""
    HessianLasso(A, vm)

Matrix-free Hessian operator for the least-squares term ½‖Ax − b‖², i.e. `AᵀA`.
`A` is the data matrix. `vm` is a scratch vector that receives the intermediate
product `A * x`. It must have the same array type and element type as `A`.
"""
struct HessianLasso{T, Mat <: AbstractMatrix{T}, Vec <: AbstractVector{T}} <: HessianOperator
    A::Mat     # data matrix (CPU Array or GPU CuArray)
    vm::Vec    # workspace for A * x
end
# Apply the Lasso Hessian without forming it: y = Aᵀ(A x).
# `H.vm` holds the intermediate product A x, so no temporary array is allocated here.
# Returns `y`, following the `LinearAlgebra.mul!(Y, A, B) -> Y` convention.
function LinearAlgebra.mul!(y, H::HessianLasso, x)
    mul!(H.vm, H.A, x)     # vm = A x
    mul!(y, H.A', H.vm)    # y  = Aᵀ vm
    return y
end
# The operator depends only on the fixed matrix A, so there is nothing to refresh
# between solver iterations.
# NOTE(review): this defines `update!` in the current module. If GeNIOS is meant to
# dispatch to it, it may need to be written as `GeNIOS.update!`; confirm against GeNIOS.
update!(::HessianLasso, ::Solver) = nothing
# **********************************************************
# choose between Array (CPU) and CuArray (GPU) here
# **********************************************************
# Element type for all solver data.
T = Float32
# Use the GPU when CUDA.jl reports a working device; otherwise fall back to the CPU so
# the example still runs. To force a backend, replace this with `AT = CuArray{T}` or
# `AT = Array{T}`.
use_gpu = CUDA.functional()
use_gpu || @warn "No functional CUDA device found; running on the CPU with Array{$T}."
AT = use_gpu ? CuArray{T} : Array{T}
# Move the data and buffers to the chosen array type and element type.
A = AT(A)
b = AT(b)
λ = convert(eltype(A), λ)   # keep λ in the same precision as the data
solver_c = AT(solver_c)
tmp_arr = AT(tmp_arr)
# Parameters handed to f, grad_f!, g and prox_g!; `tmp` is a shared length-m workspace.
params = (; A=A, b=b, tmp=tmp_arr, λ=λ)
# Smooth term f(x) = ½‖Ax − b‖².
# The residual is formed in the preallocated workspace `p.tmp` (overwritten on every
# call) instead of allocating a new vector.
function f(x, p)
    r = p.tmp
    mul!(r, p.A, x)    # r = A x
    r .-= p.b          # r = A x − b
    return sum(abs2, r) / 2
end
# Gradient of f(x) = ½‖Ax − b‖², written into `g`: g = Aᵀ(A x − b).
# Uses `p.tmp` as scratch for the residual and returns `nothing`.
function grad_f!(g, x, p)
    r = p.tmp
    mul!(r, p.A, x)     # r = A x
    r .-= p.b           # r = A x − b
    mul!(g, p.A', r)    # g = Aᵀ r
    return nothing
end
# **********************************************************
# NOTE: the Hessian's cache vector must use the same array type (device and element
# type) as A, because `mul!(H.vm, H.A, x)` writes A x into it.
# **********************************************************
Hf = HessianLasso(A, AT(zeros(T, m)))
# Nonsmooth term g(z) = λ‖z‖₁.
g(z, p) = p.λ * sum(abs, z)
# Proximal operator of g(z) = λ‖z‖₁ with threshold κ = λ/ρ (elementwise soft
# thresholding):
#     v = sign(z) ⋅ max(0, |z| − κ)
# The result is written into `v`, which is also returned.
# κ is converted to eltype(z) so that, e.g., Float32 (GPU) data combined with a
# Float64 ρ or λ does not raise a MethodError or promote the broadcast kernel.
function prox_g!(v, z, ρ, p)
    κ = convert(eltype(z), p.λ / ρ)
    @. v = sign(z) * max(zero(z), abs(z) - κ)
    return v
end
# Assemble the generic problem: minimize f(x) + g(z), with x and z coupled through a
# linear constraint given by M and c (here M = I and c = solver_c, a zero vector).
solver = GeNIOS.GenericSolver(
    f, grad_f!, Hf,    # smooth term: value, in-place gradient, Hessian operator
    g, prox_g!,        # nonsmooth term: value and proximal operator
    I, solver_c;       # constraint data M, c
    params=params,
)
# Solve in precision T with `relax` and `verbose` enabled and preconditioning disabled,
# then report the root-mean-square residual ‖A z − b‖₂ / √m of the returned z iterate.
res = solve!(solver; options=GeNIOS.SolverOptions{T,T}(;
    relax=true,
    verbose=true,
    precondition=false,
    update_preconditioner=false,
))
rmse = let resid = A * solver.zk - b
    sqrt(1 / m * norm(resid, 2)^2)
end
println("Final RMSE: ", round(rmse; digits=8))