Improving the Efficiency of Gradient Descent Algorithms Applied to Optimization Problems with Dynamical Constraints
| Main Authors | , , , |
|---|---|
| Format | Journal Article |
| Language | English |
| Published | 26.08.2022 |
| Subjects | |
| Online Access | Get full text |
| DOI | 10.48550/arxiv.2208.12834 |
| Summary: | We introduce two block coordinate descent algorithms for solving optimization problems with ordinary differential equations (ODEs) as dynamical constraints. The algorithms do not need to implement direct or adjoint sensitivity analysis methods to evaluate loss function gradients. They result from a reformulation of the original problem as an equivalent optimization problem with equality constraints. The algorithms naturally follow from steps aimed at recovering the gradient descent algorithm based on ODE solvers that explicitly account for the sensitivity of the ODE solution. In our first proposed algorithm we avoid explicitly solving the ODE by integrating the ODE solver as a sequence of implicit constraints. In our second algorithm, we use an ODE solver to reset the ODE solution, but no direct or adjoint sensitivity analysis methods are used. Both algorithms accept mini-batch implementations and show significant efficiency benefits from GPU-based parallelization. We demonstrate the performance of the algorithms when applied to learning the parameters of the Cucker-Smale model. The algorithms are compared with gradient descent algorithms based on ODE solvers endowed with sensitivity analysis capabilities, for various state sizes, using Pytorch and Jax implementations. The experimental results demonstrate that the proposed algorithms are at least 4x faster than the Pytorch implementations, and at least 16x faster than the Jax implementations. For large versions of the Cucker-Smale model, our Jax implementation is thousands of times faster than the sensitivity analysis-based implementation. In addition, our algorithms generate more accurate results on both training and test data. Such gains in computational efficiency are paramount for algorithms that implement real-time parameter estimation, such as diagnosis algorithms. |
|---|---|
| DOI: | 10.48550/arxiv.2208.12834 |
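
The record itself contains no code, but the idea described in the summary (replacing direct or adjoint sensitivity analysis with an equality-constrained reformulation solved by block coordinate descent) can be illustrated with a short sketch. The example below is an assumption-laden illustration, not the paper's algorithm: it uses the standard Cucker-Smale alignment dynamics with kernel psi(r) = K / (sigma^2 + r^2)^beta, an explicit Euler discretization whose step equations act as the equality constraints, a quadratic penalty in place of whatever constraint handling the paper actually uses, and hypothetical names (`cucker_smale_rhs`, `penalty_loss`, `block_coordinate_step`, `rho`). It is written in JAX, one of the two frameworks mentioned in the summary.

```python
import jax
import jax.numpy as jnp

def psi(r2, K, sigma, beta):
    # Cucker-Smale communication kernel psi(r) = K / (sigma^2 + r^2)^beta,
    # evaluated here on the squared distance r2.
    return K / (sigma**2 + r2) ** beta

def cucker_smale_rhs(state, params):
    # state: (N, 2*d) array, positions x and velocities v concatenated per agent.
    x, v = jnp.split(state, 2, axis=1)
    K, sigma, beta = params
    diff_x = x[:, None, :] - x[None, :, :]        # pairwise position differences
    diff_v = v[None, :, :] - v[:, None, :]        # v_j - v_i
    r2 = jnp.sum(diff_x**2, axis=-1)              # squared pairwise distances
    w = psi(r2, K, sigma, beta)                   # (N, N) interaction weights
    dv = jnp.mean(w[..., None] * diff_v, axis=1)  # velocity-alignment force per agent
    return jnp.concatenate([v, dv], axis=1)

def step_residual(s_next, s, params, h):
    # Equality constraint of one explicit-Euler solver step:
    # s_{k+1} - s_k - h * f(s_k, params) = 0.
    return s_next - (s + h * cucker_smale_rhs(s, params))

def penalty_loss(traj, params, data, h, rho):
    # traj: (T+1, N, 2*d) discretized trajectory treated as decision variables.
    # Data-fit term plus a quadratic penalty on the solver-step constraints.
    fit = jnp.mean((traj - data) ** 2)
    res = jax.vmap(lambda s_next, s: step_residual(s_next, s, params, h))(traj[1:], traj[:-1])
    return fit + rho * jnp.mean(res**2)

@jax.jit
def block_coordinate_step(traj, params, data, h, rho, lr_traj, lr_params):
    # Block 1: gradient step in the trajectory variables, parameters held fixed.
    traj = traj - lr_traj * jax.grad(penalty_loss, argnums=0)(traj, params, data, h, rho)
    # Block 2: gradient step in the model parameters, trajectory held fixed.
    params = params - lr_params * jax.grad(penalty_loss, argnums=1)(traj, params, data, h, rho)
    return traj, params

# Hypothetical usage: N agents in d dimensions observed at T+1 time points.
key = jax.random.PRNGKey(0)
N, d, T, h = 8, 2, 50, 0.05
data = jax.random.normal(key, (T + 1, N, 2 * d))  # stand-in for observed trajectories
traj = jnp.copy(data)                             # initialize decision variables at the data
params = jnp.array([1.0, 1.0, 0.5])               # initial guess for (K, sigma, beta)
for _ in range(200):
    traj, params = block_coordinate_step(traj, params, data, h, 10.0, 1e-2, 1e-2)
```

Alternating the two gradient blocks never differentiates through an ODE solve: each block only differentiates algebraic residuals, which is the property the summary attributes to the proposed algorithms. The full-trajectory observations, the Euler discretization, and the penalty formulation are simplifying assumptions made here for the sake of a self-contained example.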