!! Copyright (C) 2002-2006 M. Marques, A. Castro, A. Rubio, G. Bertsch
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

#include "global.h"

module propagation_oct_m
  use absorbing_boundaries_oct_m
  use batch_oct_m
  use batch_ops_oct_m
  use controlfunction_oct_m
  use debug_oct_m
  use density_oct_m
  use derivatives_oct_m
  use electrons_oct_m
  use electron_space_oct_m
  use energy_calc_oct_m
  use epot_oct_m
  use excited_states_oct_m
  use ext_partner_list_oct_m
  use forces_oct_m
  use gauge_field_oct_m
  use global_oct_m
  use grid_oct_m
  use hamiltonian_elec_oct_m
  use interaction_partner_oct_m
  use ion_dynamics_oct_m
  use ions_oct_m
  use, intrinsic :: iso_fortran_env
  use kpoints_oct_m
  use ks_potential_oct_m
  use lasers_oct_m
  use lalg_basic_oct_m
  use loct_oct_m
  use mesh_oct_m
  use mesh_function_oct_m
  use messages_oct_m
  use mpi_oct_m
  use multicomm_oct_m
  use namespace_oct_m
  use oct_exchange_oct_m
  use opt_control_state_oct_m
  use parser_oct_m
  use perturbation_ionic_oct_m
  use potential_interpolation_oct_m
  use propagator_elec_oct_m
  use propagator_base_oct_m
  use profiling_oct_m
  use restart_oct_m
  use space_oct_m
  use states_elec_oct_m
  use states_elec_dim_oct_m
  use states_elec_restart_oct_m
  use target_oct_m
  use target_low_oct_m
  use td_oct_m
  use td_write_oct_m
  use v_ks_oct_m

  implicit none

  private

  public ::               &
    propagation_mod_init, &
    propagate_forward,    &
    propagate_backward,   &
    fwd_step,             &
    bwd_step,             &
    bwd_step_2,           &
    oct_prop_t,           &
    oct_prop_init,        &
    oct_prop_check,       &
    oct_prop_end


  !> Bookkeeping object for storing and retrieving electronic states
  !! along a QOCT propagation. It is the first argument of
  !! oct_prop_dump_states, oct_prop_load_states and oct_prop_check.
  !! NOTE(review): the component descriptions below are inferred from
  !! their names; the routine that fills them (oct_prop_init) is not
  !! shown here -- confirm.
  type oct_prop_t
    private
    !> Number of checkpoints (presumably where states are kept).
    integer :: number_checkpoints
    !> Time-step indices associated with the checkpoints (presumably).
    integer, allocatable :: iter(:)
    !> Number of time steps of the propagation (presumably).
    integer :: niter
    !> Directory name used for the stored states (presumably).
    character(len=100) :: dirname
    !> Restart object presumably used when reading states back.
    type(restart_t) :: restart_load
    !> Restart object presumably used when writing states.
    type(restart_t) :: restart_dump
  end type oct_prop_t


  !> Module variables. They are copies of the arguments of
  !! propagation_mod_init, are set only there, and are meant to stay
  !! invariant during the whole run.
  !! zbr98_ and gradients_ are never both .true. (asserted in
  !! propagation_mod_init).
  integer :: niter_
  integer :: number_checkpoints_
  real(real64)   :: eta_
  real(real64)   :: delta_
  logical :: zbr98_
  logical :: gradients_

contains

  !> Stores in module variables the parameters that the QOCT
  !! propagations of this module need and that stay invariant during
  !! the whole run. It must be called before any QOCT propagation is
  !! done. There is no matching finalization routine: no
  !! propagation_mod_close is needed.
  !!
  !! The zbr98 and gradients switches are mutually exclusive.
  subroutine propagation_mod_init(niter, eta, delta, number_checkpoints, zbr98, gradients)
    integer,      intent(in) :: niter, number_checkpoints
    real(real64), intent(in) :: eta, delta
    logical,      intent(in) :: zbr98, gradients

    ASSERT(.not. (zbr98 .and. gradients))

    PUSH_SUB(propagation_mod_init)

    ! Integer settings.
    niter_              = niter
    number_checkpoints_ = number_checkpoints

    ! Real-valued settings.
    eta_   = eta
    delta_ = delta

    ! Logical switches.
    zbr98_     = zbr98
    gradients_ = gradients

    POP_SUB(propagation_mod_init)
  end subroutine propagation_mod_init
  ! ---------------------------------------------------------


  !> ---------------------------------------------------------
  !! Performs a full forward propagation of the electronic state
  !! contained in qcpsi, with the control field specified in par.
  !!
  !! - If write_iter is present and set to .true., the run is also
  !!   written down through the td_write module.
  !! - If prop is present, the states are dumped through
  !!   oct_prop_dump_states at step 0 and after every time step.
  !! - The target tg is evaluated (target_tdcalc) at step 0 and
  !!   after every time step.
  !! - opt_control_get_classical / opt_control_set_classical are
  !!   called at the beginning and at the end to exchange the
  !!   classical (ionic) variables between sys%ions and qcpsi. For
  !!   velocity and hhgnew targets the ionic positions are reset to
  !!   their initial values before the final exchange.
  !! ---------------------------------------------------------
  subroutine propagate_forward(sys, td, par, tg, qcpsi, prop, write_iter)
    type(electrons_t),          intent(inout)  :: sys
    type(td_t),                 intent(inout)  :: td
    type(controlfunction_t),    intent(in)     :: par
    type(target_t),             intent(inout)  :: tg
    type(opt_control_state_t),  intent(inout)  :: qcpsi
    type(oct_prop_t), optional, intent(inout)  :: prop
    logical, optional,          intent(in)     :: write_iter

    integer :: ii, istep, ierr
    ! These flags are deliberately NOT initialized in their declarations:
    ! an initializer would give them the implicit SAVE attribute, and
    ! their values would then leak from one call into the next. They are
    ! assigned explicitly at the start of every call.
    logical :: write_iter_
    logical :: vel_target_
    type(td_write_t)           :: write_handler
    real(real64), allocatable :: x_initial(:,:)
    type(states_elec_t), pointer :: psi

    real(real64) :: init_time, final_time

    PUSH_SUB(propagate_forward)

    message(1) = "Info: Forward propagation."
    call messages_info(1, namespace=sys%namespace)

    call controlfunction_to_h(par, sys%ext_partners)

    write_iter_ = .false.
    if (present(write_iter)) write_iter_ = write_iter
    vel_target_ = .false.

    psi => opt_control_point_qs(qcpsi)
    if (sys%st%pack_states .and. sys%hm%apply_packed()) call psi%pack()
    call opt_control_get_classical(sys%ions, qcpsi)

    if (write_iter_) then
      call td_write_init(write_handler, sys%namespace, sys%space, sys%outp, sys%gr, sys%st, sys%hm, &
        sys%ions, sys%ext_partners, sys%ks, td%ions_dyn%ions_move(), &
        list_has_gauge_field(sys%ext_partners), sys%hm%kick, td%iter, td%max_iter, &
        td%dt, sys%mc)
      call td_write_data(write_handler)
    end if

    call hamiltonian_elec_not_adjoint(sys%hm)

    ! setup the Hamiltonian
    call density_calc(psi, sys%gr, psi%rho)
    call v_ks_calc(sys%ks, sys%namespace, sys%space, sys%hm, psi, sys%ions, sys%ext_partners, time = M_ZERO)
    call sys%hm%ks_pot%run_zero_iter(td%tr%vks_old)
    if (td%ions_dyn%ions_move()) then
      call hamiltonian_elec_epot_generate(sys%hm, sys%namespace,  sys%space, sys%gr, sys%ions, &
        sys%ext_partners, psi, time = M_ZERO)
      call forces_calculate(sys%gr, sys%namespace, sys%ions, sys%hm, sys%ext_partners, psi, sys%ks, t = M_ZERO, dt = td%dt)
    end if


    if (target_type(tg) == oct_tg_hhgnew) then
      call target_init_propagation(tg)
    end if

    if (target_type(tg) == oct_tg_velocity .or. target_type(tg) == oct_tg_hhgnew) then
      ! Keep the initial ionic positions; they are restored after the propagation.
      SAFE_ALLOCATE_SOURCE_A(x_initial, sys%ions%pos)
      vel_target_ = .true.
      sys%ions%vel = M_ZERO
      sys%ions%tot_force = M_ZERO
    end if

    if (.not. target_move_ions(tg)) call epot_precalc_local_potential(sys%hm%ep, sys%namespace, sys%gr, sys%ions)

    call target_tdcalc(tg, sys%namespace, sys%space, sys%hm, sys%gr, sys%ions, sys%ext_partners, psi, 0, td%max_iter)

    if (present(prop)) then
      call oct_prop_dump_states(prop, sys%space, 0, psi, sys%gr, sys%kpoints, ierr)
      if (ierr /= 0) then
        message(1) = "Unable to write OCT states restart."
        call messages_warning(1, namespace=sys%namespace)
      end if
    end if

    init_time = loct_clock()
    if (mpi_grp_is_root(mpi_world)) call loct_progress_bar(-1, td%max_iter)

    ! ii - 1 is the number of time steps since td_write_data was last called.
    ii = 1
    do istep = 1, td%max_iter
      ! time-iterate wavefunctions

      call propagator_elec_dt(sys%ks, sys%namespace, sys%space, sys%hm, sys%gr, psi, td%tr, istep*td%dt, td%dt, istep, &
        td%ions_dyn, sys%ions, sys%ext_partners, sys%mc, sys%outp, td%write_handler)

      if (present(prop)) then
        call oct_prop_dump_states(prop, sys%space, istep, psi, sys%gr, sys%kpoints, ierr)
        if (ierr /= 0) then
          message(1) = "Unable to write OCT states restart."
          call messages_warning(1, namespace=sys%namespace)
        end if
      end if

      ! update
      call v_ks_calc(sys%ks, sys%namespace, sys%space, sys%hm, psi, sys%ions, sys%ext_partners, time = istep*td%dt)
      call energy_calc_total(sys%namespace, sys%space, sys%hm, sys%gr, psi, sys%ext_partners)

      if (sys%hm%abs_boundaries%abtype == MASK_ABSORBING) call zvmask(sys%gr, sys%hm, psi)

      ! if td_target
      call target_tdcalc(tg, sys%namespace, sys%space, sys%hm, sys%gr, sys%ions, sys%ext_partners, psi, istep, td%max_iter)

      ! only write in final run
      if (write_iter_) then
        call td_write_iter(write_handler, sys%namespace, sys%space, sys%outp, sys%gr, psi, sys%hm,  sys%ions, &
          sys%ext_partners, sys%hm%kick, sys%ks, td%dt, istep, sys%mc, sys%td%recalculate_gs)
        ii = ii + 1
        if (any(ii == sys%outp%output_interval + 1) .or. istep == td%max_iter) then ! output
          ! NOTE(review): this permanently modifies sys%outp%output_interval;
          ! confirm that later users of sys%outp expect it.
          if (istep == td%max_iter) sys%outp%output_interval = ii - 1
          ! Restart the step count, so that the data is written again once
          ! the next interval has elapsed (not only once before the last step).
          ii = 1
          call td_write_data(write_handler)
        end if
      end if

      if ((mod(istep, 100) == 0) .and. mpi_grp_is_root(mpi_world)) call loct_progress_bar(istep, td%max_iter)
    end do
    if (mpi_grp_is_root(mpi_world)) write(stdout, '(1x)')

    final_time = loct_clock()
    write(message(1),'(a,f12.2,a)') 'Propagation time: ', final_time - init_time, ' seconds.'
    call messages_info(1, namespace=sys%namespace)

    if (vel_target_) then
      sys%ions%pos = x_initial
      SAFE_DEALLOCATE_A(x_initial)
    end if

    call opt_control_set_classical(sys%ions, qcpsi)

    if (write_iter_) call td_write_end(write_handler)
    if (sys%st%pack_states .and. sys%hm%apply_packed()) call psi%unpack()
    nullify(psi)
    POP_SUB(propagate_forward)
  end subroutine propagate_forward
  ! ---------------------------------------------------------


  !> ---------------------------------------------------------
  !! Full backward propagation of the electronic state held in
  !! qcpsi, from time step td%max_iter down to step 0, using the
  !! adjoint of the Hamiltonian (hamiltonian_elec_adjoint) with the
  !! external fields already set in it.
  !!
  !! The states are dumped through prop at step td%max_iter and
  !! after every backward step.
  !! ---------------------------------------------------------
  subroutine propagate_backward(sys, td, qcpsi, prop)
    type(electrons_t),         intent(inout) :: sys
    type(td_t),                intent(inout) :: td
    type(opt_control_state_t), intent(inout) :: qcpsi
    type(oct_prop_t),          intent(inout) :: prop

    type(states_elec_t), pointer :: psi
    integer :: nsteps_done, istep, ierr

    PUSH_SUB(propagate_backward)

    message(1) = "Info: Backward propagation."
    call messages_info(1, namespace=sys%namespace)

    psi => opt_control_point_qs(qcpsi)
    if (sys%st%pack_states .and. sys%hm%apply_packed()) call psi%pack()

    call hamiltonian_elec_adjoint(sys%hm)

    ! Build the Kohn-Sham potential from the density of the starting state.
    call density_calc(psi, sys%gr, psi%rho)
    call v_ks_calc(sys%ks, sys%namespace, sys%space, sys%hm, psi, sys%ions, sys%ext_partners)
    call sys%hm%ks_pot%run_zero_iter(td%tr%vks_old)

    call oct_prop_dump_states(prop, sys%space, td%max_iter, psi, sys%gr, sys%kpoints, ierr)
    if (ierr /= 0) then
      message(1) = "Unable to write OCT states restart."
      call messages_warning(1, namespace=sys%namespace)
    end if

    if (mpi_grp_is_root(mpi_world)) call loct_progress_bar(-1, td%max_iter)

    ! nsteps_done counts the backward steps performed so far (including the
    ! current one); istep is the time-step index the current step starts from.
    do nsteps_done = 1, td%max_iter
      istep = td%max_iter - nsteps_done + 1

      call propagator_elec_dt(sys%ks, sys%namespace, sys%space, sys%hm, sys%gr, psi, td%tr, &
        (istep - 1)*td%dt, -td%dt, istep-1, td%ions_dyn, sys%ions, sys%ext_partners, sys%mc, sys%outp, td%write_handler)

      call oct_prop_dump_states(prop, sys%space, istep - 1, psi, sys%gr, sys%kpoints, ierr)
      if (ierr /= 0) then
        message(1) = "Unable to write OCT states restart."
        call messages_warning(1, namespace=sys%namespace)
      end if

      call v_ks_calc(sys%ks, sys%namespace, sys%space, sys%hm, psi, sys%ions, sys%ext_partners)

      if (mod(istep, 100) == 0 .and. mpi_grp_is_root(mpi_world)) then
        call loct_progress_bar(nsteps_done, td%max_iter)
      end if
    end do

    if (sys%st%pack_states .and. sys%hm%apply_packed()) call psi%unpack()
    nullify(psi)
    POP_SUB(propagate_backward)
  end subroutine propagate_backward
  ! ---------------------------------------------------------


  !> ---------------------------------------------------------
  !! Performs a forward propagation on the state psi and on the
  !! Lagrange-multiplier state chi. It also updates the control
  !! function par, according to the following scheme:
  !!
  !! |chi> --> U[par_chi](T, 0)|chi>
  !! par = par[|psi>, |chi>]
  !! |psi> --> U[par](T, 0)|psi>
  !!
  !! Note that the control functions "par" are updated on the
  !! fly, so that the propagation of psi is performed with the
  !! "new" control functions.
  !!
  !! chi is held in a local copy of qcpsi (qcchi) whose states are
  !! loaded from prop_chi at step 0. The psi states are dumped into
  !! prop_psi at step 0 and after every step.
  !!
  !! An auxiliary copy psi2 of psi is propagated with the previous
  !! control function (par_prev) when the target is time dependent,
  !! or when the theory level is not independent particles and the
  !! Hxc potential is not frozen.
  !! --------------------------------------------------------
  subroutine fwd_step(sys, td, tg, par, par_chi, qcpsi, prop_chi, prop_psi)
    type(electrons_t),         intent(inout) :: sys
    type(td_t),                intent(inout) :: td
    type(target_t),            intent(inout) :: tg
    type(controlfunction_t),   intent(inout) :: par
    type(controlfunction_t),   intent(in)    :: par_chi
    type(opt_control_state_t), intent(inout) :: qcpsi
    type(oct_prop_t),          intent(inout) :: prop_chi
    type(oct_prop_t),          intent(inout) :: prop_psi

    integer :: i, ierr
    logical :: aux_fwd_propagation
    type(states_elec_t) :: psi2
    type(opt_control_state_t) :: qcchi
    type(controlfunction_t) :: par_prev
    type(propagator_base_t) :: tr_chi, tr_psi2
    type(states_elec_t), pointer :: psi, chi

    PUSH_SUB(fwd_step)

    message(1) = "Info: Forward propagation."
    call messages_info(1, namespace=sys%namespace)

    call controlfunction_to_realtime(par)

    call opt_control_state_null(qcchi)
    call opt_control_state_copy(qcchi, qcpsi)

    psi => opt_control_point_qs(qcpsi)
    chi => opt_control_point_qs(qcchi)
    call propagator_elec_copy(tr_chi, td%tr)
    ! The propagation of chi should not be self-consistent, because the Kohn-Sham
    ! potential used is the one created by psi. Note, however, that it is likely that
    ! the first two iterations are done self-consistently nonetheless.
    call propagator_elec_remove_scf_prop(tr_chi)

    ! This flag decides both whether psi2/par_prev/tr_psi2 are created and
    ! whether they are released at the end; it must be the single source of truth.
    aux_fwd_propagation = (target_mode(tg) == oct_targetmode_td .or. &
      (sys%hm%theory_level /= INDEPENDENT_PARTICLES .and. &
      .not. sys%ks%frozen_hxc))
    if (aux_fwd_propagation) then
      call states_elec_copy(psi2, psi)
      call controlfunction_copy(par_prev, par)
      if (sys%st%pack_states .and. sys%hm%apply_packed()) call psi2%pack()
    end if

    ! setup forward propagation
    call density_calc(psi, sys%gr, psi%rho)
    call v_ks_calc(sys%ks, sys%namespace, sys%space, sys%hm, psi, sys%ions, sys%ext_partners)
    call sys%hm%ks_pot%run_zero_iter(td%tr%vks_old)
    call sys%hm%ks_pot%run_zero_iter(tr_chi%vks_old)
    if (aux_fwd_propagation) then
      call propagator_elec_copy(tr_psi2, td%tr)
      call sys%hm%ks_pot%run_zero_iter(tr_psi2%vks_old)
    end if

    call oct_prop_dump_states(prop_psi, sys%space, 0, psi, sys%gr, sys%kpoints, ierr)
    if (ierr /= 0) then
      message(1) = "Unable to write OCT states restart."
      call messages_warning(1, namespace=sys%namespace)
    end if
    call oct_prop_load_states(prop_chi, sys%namespace, sys%space, chi, sys%gr, sys%kpoints, 0, ierr)
    if (ierr /= 0) then
      message(1) = "Unable to read OCT states restart."
      call messages_fatal(1, namespace=sys%namespace)
    end if

    if (sys%st%pack_states .and. sys%hm%apply_packed()) call psi%pack()
    if (sys%st%pack_states .and. sys%hm%apply_packed()) call chi%pack()

    do i = 1, td%max_iter
      ! New value of the control field, then propagate chi with par_chi.
      call update_field(i, par, sys%space, sys%gr, sys%hm, sys%ext_partners, sys%ions, qcpsi, qcchi, par_chi, dir = 'f')
      call update_hamiltonian_elec_chi(i, sys%namespace, sys%space, sys%gr, sys%ks, sys%hm, sys%ext_partners, &
        td, tg, par_chi, sys%ions, psi2)
      call sys%hm%update(sys%gr, sys%namespace, sys%space, sys%ext_partners, time = (i - 1)*td%dt)
      call propagator_elec_dt(sys%ks, sys%namespace, sys%space, sys%hm, sys%gr, chi, tr_chi, i*td%dt, td%dt, i, &
        td%ions_dyn, sys%ions, sys%ext_partners, sys%mc, sys%outp, td%write_handler)
      ! Auxiliary copy propagated with the previous control function.
      if (aux_fwd_propagation) then
        call update_hamiltonian_elec_psi(i, sys%namespace, sys%space, sys%gr, sys%ks, sys%hm, sys%ext_partners, &
          td, tg, par_prev, psi2, sys%ions)
        call propagator_elec_dt(sys%ks, sys%namespace, sys%space, sys%hm, sys%gr, psi2, tr_psi2, i*td%dt, td%dt, i, &
          td%ions_dyn, sys%ions, sys%ext_partners, sys%mc, sys%outp, td%write_handler)
      end if
      ! psi is propagated with the freshly updated control function par.
      call update_hamiltonian_elec_psi(i, sys%namespace, sys%space, sys%gr, sys%ks, sys%hm, &
        sys%ext_partners, td, tg, par, psi, sys%ions)
      call sys%hm%update(sys%gr, sys%namespace, sys%space, sys%ext_partners, time = (i - 1)*td%dt)
      call propagator_elec_dt(sys%ks, sys%namespace, sys%space, sys%hm, sys%gr, psi, td%tr, i*td%dt, td%dt, i, &
        td%ions_dyn, sys%ions, sys%ext_partners, sys%mc, sys%outp, td%write_handler)
      call target_tdcalc(tg, sys%namespace, sys%space, sys%hm, sys%gr, sys%ions, sys%ext_partners, psi, i, td%max_iter)

      call oct_prop_dump_states(prop_psi, sys%space, i, psi, sys%gr, sys%kpoints, ierr)
      if (ierr /= 0) then
        message(1) = "Unable to write OCT states restart."
        call messages_warning(1, namespace=sys%namespace)
      end if
      call oct_prop_check(prop_chi, sys%namespace, sys%space, chi, sys%gr, sys%kpoints, i)
    end do
    call update_field(td%max_iter+1, par, sys%space, sys%gr, sys%hm, sys%ext_partners, sys%ions, qcpsi, qcchi, par_chi, dir = 'f')

    call density_calc(psi, sys%gr, psi%rho)
    call v_ks_calc(sys%ks, sys%namespace, sys%space, sys%hm, psi, sys%ions, sys%ext_partners)

    ! Release the auxiliary objects under the same flag that created them.
    if (aux_fwd_propagation) then
      call states_elec_end(psi2)
      call controlfunction_end(par_prev)
    end if

    call controlfunction_to_basis(par)
    if (aux_fwd_propagation) call propagator_elec_end(tr_psi2)
    ! NOTE(review): only the electronic part of qcchi is released here;
    ! confirm whether its classical components also need to be freed.
    call states_elec_end(chi)
    call propagator_elec_end(tr_chi)
    if (sys%st%pack_states .and. sys%hm%apply_packed()) call psi%unpack()
    nullify(psi)
    nullify(chi)
    POP_SUB(fwd_step)
  end subroutine fwd_step
  ! ---------------------------------------------------------


  !> --------------------------------------------------------
  !! Backward propagation of the state psi and of the
  !! Lagrange-multiplier state chi:
  !!
  !! |psi> --> U[par](0, T)|psi>
  !! par_chi = par_chi[|psi>, |chi>]
  !! |chi> --> U[par_chi](0, T)|chi>
  !!
  !! psi is a local object: it starts as a copy of chi and is then
  !! loaded from prop_psi at step td%max_iter. The chi states are
  !! dumped into prop_chi at step td%max_iter and after every step.
  !! The sign of td%dt is reversed for the duration of the loop and
  !! restored before returning.
  !! --------------------------------------------------------
  subroutine bwd_step(sys, td, tg, par, par_chi, qcchi, prop_chi, prop_psi)
    type(electrons_t),         intent(inout) :: sys
    type(td_t),                intent(inout) :: td
    type(target_t),            intent(inout) :: tg
    type(controlfunction_t),   intent(in)    :: par
    type(controlfunction_t),   intent(inout) :: par_chi
    type(opt_control_state_t), intent(inout) :: qcchi
    type(oct_prop_t),          intent(inout) :: prop_chi
    type(oct_prop_t),          intent(inout) :: prop_psi

    type(states_elec_t), pointer :: psi, chi
    type(opt_control_state_t) :: qcpsi
    type(propagator_base_t) :: tr_chi
    integer :: it, ierr

    PUSH_SUB(bwd_step)

    message(1) = "Info: Backward propagation."
    call messages_info(1, namespace=sys%namespace)

    call controlfunction_to_realtime(par_chi)

    chi => opt_control_point_qs(qcchi)
    psi => opt_control_point_qs(qcpsi)

    ! chi is propagated without self-consistency, since its Kohn-Sham potential
    ! is the one generated by psi (the first two iterations may still end up
    ! being done self-consistently).
    call propagator_elec_copy(tr_chi, td%tr)
    call propagator_elec_remove_scf_prop(tr_chi)

    ! Create psi as a copy of chi, then load the stored states of the last step.
    call states_elec_copy(psi, chi)
    call oct_prop_load_states(prop_psi, sys%namespace, sys%space, psi, sys%gr, sys%kpoints, td%max_iter, ierr)
    if (ierr /= 0) then
      message(1) = "Unable to read OCT states restart."
      call messages_fatal(1, namespace=sys%namespace)
    end if
    if (sys%st%pack_states .and. sys%hm%apply_packed()) then
      call psi%pack()
      call chi%pack()
    end if

    ! Hamiltonian generated by psi at the final time.
    call density_calc(psi, sys%gr, psi%rho)
    call v_ks_calc(sys%ks, sys%namespace, sys%space, sys%hm, psi, sys%ions, sys%ext_partners)
    call sys%hm%update(sys%gr, sys%namespace, sys%space, sys%ext_partners)
    call sys%hm%ks_pot%run_zero_iter(td%tr%vks_old)
    call sys%hm%ks_pot%run_zero_iter(tr_chi%vks_old)

    ! Reverse the time step for the whole backward loop.
    td%dt = -td%dt
    call oct_prop_dump_states(prop_chi, sys%space, td%max_iter, chi, sys%gr, sys%kpoints, ierr)
    if (ierr /= 0) then
      message(1) = "Unable to write OCT states restart."
      call messages_warning(1, namespace=sys%namespace)
    end if

    backward: do it = td%max_iter, 1, -1
      call oct_prop_check(prop_psi, sys%namespace, sys%space, psi, sys%gr, sys%kpoints, it)
      call update_field(it, par_chi, sys%space, sys%gr, sys%hm, sys%ext_partners, sys%ions, qcpsi, qcchi, par, dir = 'b')

      ! chi: one step from index it to it-1, driven by par_chi.
      call update_hamiltonian_elec_chi(it-1, sys%namespace, sys%space, sys%gr, sys%ks, sys%hm, sys%ext_partners, &
        td, tg, par_chi, sys%ions, psi)
      call sys%hm%update(sys%gr, sys%namespace, sys%space, sys%ext_partners, time = abs(it*td%dt))
      call propagator_elec_dt(sys%ks, sys%namespace, sys%space, sys%hm, sys%gr, chi, tr_chi, abs((it-1)*td%dt), td%dt, &
        it-1, td%ions_dyn, sys%ions, sys%ext_partners, sys%mc, sys%outp, td%write_handler)
      call oct_prop_dump_states(prop_chi, sys%space, it-1, chi, sys%gr, sys%kpoints, ierr)
      if (ierr /= 0) then
        message(1) = "Unable to write OCT states restart."
        call messages_warning(1, namespace=sys%namespace)
      end if

      ! psi: one step from index it to it-1, driven by par.
      call update_hamiltonian_elec_psi(it-1, sys%namespace, sys%space, sys%gr, sys%ks, sys%hm, sys%ext_partners, &
        td, tg, par, psi, sys%ions)
      call sys%hm%update(sys%gr, sys%namespace, sys%space, sys%ext_partners, time = abs(it*td%dt))
      call propagator_elec_dt(sys%ks, sys%namespace, sys%space, sys%hm, sys%gr, psi, td%tr, abs((it-1)*td%dt), td%dt, &
        it-1, td%ions_dyn, sys%ions, sys%ext_partners, sys%mc, sys%outp, td%write_handler)
    end do backward

    ! Back to the original sign of the time step.
    td%dt = -td%dt
    call update_field(0, par_chi, sys%space, sys%gr, sys%hm, sys%ext_partners, sys%ions, qcpsi, qcchi, par, dir = 'b')

    ! Leave the Hamiltonian consistent with the final psi.
    call density_calc(psi, sys%gr, psi%rho)
    call v_ks_calc(sys%ks, sys%namespace, sys%space, sys%hm, psi, sys%ions, sys%ext_partners)
    call sys%hm%update(sys%gr, sys%namespace, sys%space, sys%ext_partners)

    call controlfunction_to_basis(par_chi)
    call states_elec_end(psi)
    call propagator_elec_end(tr_chi)
    if (sys%st%pack_states .and. sys%hm%apply_packed()) call chi%unpack()
    nullify(chi)
    nullify(psi)
    POP_SUB(bwd_step)
  end subroutine bwd_step
  ! ---------------------------------------------------------


  !> --------------------------------------------------------
  !! Performs a backward propagation on the state psi and on the
  !! Lagrange-multiplier state chi, according to the following
  !! scheme:
  !!
  !! |psi> --> U[par](0, T)|psi>
  !! |chi> --> U[par](0, T)|chi>
  !!
  !! It also calculates during the propagation, a new "output" field:
  !!
  !! par_chi = par_chi[|psi>, |chi>]
  !!
  !! psi is held in a local copy of qcchi (qcpsi) whose states are
  !! loaded from prop_psi at step td%max_iter. The chi states are dumped
  !! into prop_chi at step td%max_iter and after every step. The sign
  !! of td%dt is reversed during the loop and restored afterwards.
  !!
  !! For propagators other than explicit RK4, each step propagates psi
  !! first. Then chi is propagated with a Hamiltonian built from the
  !! average of psi before and after the step (st_ref) and the average
  !! of the corresponding Hxc potentials. When ions move, the classical
  !! co-state variables q and p of qcchi are advanced with
  !! ion_dynamics_verlet_step1/step2 using forces_costate_calculate.
  !! --------------------------------------------------------
  subroutine bwd_step_2(sys, td, tg, par, par_chi, qcchi, prop_chi, prop_psi)
    type(electrons_t),                 intent(inout) :: sys
    type(td_t),                        intent(inout) :: td
    type(target_t),                    intent(inout) :: tg
    type(controlfunction_t),           intent(in)    :: par
    type(controlfunction_t),           intent(inout) :: par_chi
    type(opt_control_state_t),         intent(inout) :: qcchi
    type(oct_prop_t),                  intent(inout) :: prop_chi
    type(oct_prop_t),                  intent(inout) :: prop_psi

    integer :: i, ierr, ik, ib
    logical :: freeze
    type(propagator_base_t) :: tr_chi
    type(opt_control_state_t) :: qcpsi
    type(states_elec_t) :: st_ref
    type(states_elec_t), pointer :: chi, psi
    real(real64), pointer :: q(:, :), p(:, :)
    real(real64), allocatable :: qtildehalf(:, :), qinitial(:, :)
    real(real64), allocatable :: vhxc(:, :)
    real(real64), allocatable :: fold(:, :), fnew(:, :)
    type(ion_state_t) :: ions_state_initial, ions_state_final

    real(real64) :: init_time, final_time

    PUSH_SUB(bwd_step_2)

    ! chi and the classical co-state variables q, p all live in qcchi.
    chi => opt_control_point_qs(qcchi)
    q => opt_control_point_q(qcchi)
    p => opt_control_point_p(qcchi)
    ! Note the different index order: qtildehalf is allocated as (natoms, dim),
    ! like fold/fnew, and is later assigned from q; qinitial is allocated as
    ! (dim, natoms) and is later assigned from sys%ions%pos.
    SAFE_ALLOCATE(qtildehalf(1:sys%ions%natoms, 1:sys%ions%space%dim))
    SAFE_ALLOCATE(qinitial(1:sys%ions%space%dim, 1:sys%ions%natoms))

    call propagator_elec_copy(tr_chi, td%tr)
    ! The propagation of chi should not be self-consistent, because the Kohn-Sham
    ! potential used is the one created by psi. Note, however, that it is likely that
    ! the first two iterations are done self-consistently nonetheless.
    call propagator_elec_remove_scf_prop(tr_chi)

    ! psi is a copy of qcchi whose electronic states are then loaded from
    ! prop_psi at the final step.
    call opt_control_state_null(qcpsi)
    call opt_control_state_copy(qcpsi, qcchi)
    psi => opt_control_point_qs(qcpsi)
    call oct_prop_load_states(prop_psi, sys%namespace, sys%space, psi, sys%gr, sys%kpoints, td%max_iter, ierr)
    if (ierr /= 0) then
      message(1) = "Unable to read OCT states restart."
      call messages_fatal(1, namespace=sys%namespace)
    end if

    ! Work copy of the Hxc potential, used to save and average it at every step.
    SAFE_ALLOCATE(vhxc(1:sys%gr%np, 1:sys%hm%d%nspin))

    call density_calc(psi, sys%gr, psi%rho)
    call v_ks_calc(sys%ks, sys%namespace, sys%space, sys%hm, psi, sys%ions, sys%ext_partners)
    call sys%hm%update(sys%gr, sys%namespace, sys%space, sys%ext_partners)
    call sys%hm%ks_pot%run_zero_iter(td%tr%vks_old)
    call sys%hm%ks_pot%run_zero_iter(tr_chi%vks_old)
    ! Reverse the time step for the backward loop; restored after the loop.
    td%dt = -td%dt
    call oct_prop_dump_states(prop_chi, sys%space, td%max_iter, chi, sys%gr, sys%kpoints, ierr)
    if (ierr /= 0) then
      message(1) = "Unable to write OCT states restart."
      call messages_warning(1, namespace=sys%namespace)
    end if

    ! st_ref holds, at every step, the average of psi before and after the step.
    call states_elec_copy(st_ref, psi)
    if (sys%st%pack_states .and. sys%hm%apply_packed()) call psi%pack()
    if (sys%st%pack_states .and. sys%hm%apply_packed()) call st_ref%pack()

    if (td%ions_dyn%ions_move()) then
      call forces_calculate(sys%gr, sys%namespace, sys%ions, sys%hm, sys%ext_partners, &
        psi, sys%ks, t = td%max_iter*abs(td%dt), dt = td%dt)
    end if

    message(1) = "Info: Backward propagation."
    call messages_info(1, namespace=sys%namespace)
    if (mpi_grp_is_root(mpi_world)) call loct_progress_bar(-1, td%max_iter)

    init_time = loct_clock()

    do i = td%max_iter, 1, -1

      call oct_prop_check(prop_psi, sys%namespace, sys%space, psi, sys%gr, sys%kpoints, i)
      call update_field(i, par_chi, sys%space, sys%gr, sys%hm, sys%ext_partners, sys%ions, qcpsi, qcchi, par, dir = 'b')

      select case (td%tr%method)

      case (PROP_EXPLICIT_RUNGE_KUTTA4)

        ! Only psi is propagated explicitly here; qcchi is passed to the propagator
        ! (presumably so that chi is propagated together with psi -- TODO confirm).
        call update_hamiltonian_elec_psi(i-1, sys%namespace, sys%space, sys%gr, sys%ks, sys%hm, sys%ext_partners, &
          td, tg, par, psi, sys%ions)
        call propagator_elec_dt(sys%ks, sys%namespace, sys%space, sys%hm, sys%gr, psi, td%tr, abs((i-1)*td%dt), td%dt, &
          i-1, td%ions_dyn, sys%ions, sys%ext_partners, sys%mc, sys%outp, td%write_handler, qcchi = qcchi)

      case default

        if (td%ions_dyn%ions_move()) then
          qtildehalf = q

          ! Ionic positions after half a step (stored in qinitial); the ions
          ! are then restored to their state at the beginning of the step.
          call ion_dynamics_save_state(td%ions_dyn, sys%ions, ions_state_initial)
          call ion_dynamics_propagate(td%ions_dyn, sys%ions, abs((i-1)*td%dt), M_HALF * td%dt, sys%namespace)
          qinitial = sys%ions%pos
          call ion_dynamics_restore_state(td%ions_dyn, sys%ions, ions_state_initial)

          ! Co-state forces at the beginning of the step; first Verlet update of
          ! qtildehalf (half step) and of q (full step).
          SAFE_ALLOCATE(fold(1:sys%ions%natoms, 1:sys%ions%space%dim))
          SAFE_ALLOCATE(fnew(1:sys%ions%natoms, 1:sys%ions%space%dim))
          call forces_costate_calculate(sys%gr, sys%namespace, sys%ions, sys%hm, psi, chi, fold, q)

          call ion_dynamics_verlet_step1(sys%ions, qtildehalf, p, fold, M_HALF * td%dt)
          call ion_dynamics_verlet_step1(sys%ions, q, p, fold, td%dt)
        end if

        if (td%ions_dyn%cell_relax()) then
          call messages_not_implemented("OCT with cell dynamics")
        end if

        ! Here propagate psi one full step, and then simply interpolate to get the state
        ! at half the time interval. Perhaps one could gain some accuracy by performing two
        ! successive propagations of half time step.
        call update_hamiltonian_elec_psi(i-1, sys%namespace, sys%space, sys%gr, sys%ks, sys%hm, sys%ext_partners, &
          td, tg, par, psi, sys%ions)

        ! st_ref <- psi before the step.
        do ik = psi%d%kpt%start, psi%d%kpt%end
          do ib = psi%group%block_start, psi%group%block_end
            call psi%group%psib(ib, ik)%copy_data_to(sys%gr%np, st_ref%group%psib(ib, ik))
          end do
        end do

        ! Save the Hxc potential before propagating psi.
        vhxc(:, :) = sys%hm%ks_pot%vhxc(:, :)
        call propagator_elec_dt(sys%ks, sys%namespace, sys%space, sys%hm, sys%gr, psi, td%tr, abs((i-1)*td%dt), td%dt, &
          i-1, td%ions_dyn, sys%ions, sys%ext_partners, sys%mc, sys%outp, td%write_handler)

        ! Keep the ionic state reached by psi's step, and generate the external
        ! potential at the half-step positions for the propagation of chi.
        if (td%ions_dyn%ions_move()) then
          call ion_dynamics_save_state(td%ions_dyn, sys%ions, ions_state_final)
          sys%ions%pos = qinitial
          call hamiltonian_elec_epot_generate(sys%hm, sys%namespace, sys%space, sys%gr, sys%ions, &
            sys%ext_partners, psi, time = abs((i-1)*td%dt))
        end if

        ! st_ref <- (psi_before + psi_after) / 2.
        do ik = psi%d%kpt%start, psi%d%kpt%end
          do ib = psi%group%block_start, psi%group%block_end
            call batch_scal(sys%gr%np, cmplx(M_HALF, M_ZERO, real64) , &
              st_ref%group%psib(ib, ik))
            call batch_axpy(sys%gr%np, cmplx(M_HALF, M_ZERO, real64) , &
              psi%group%psib(ib, ik), st_ref%group%psib(ib, ik))
          end do
        end do

        ! Hxc potential averaged between before and after psi's step; chi is then
        ! propagated with the ions frozen (they are unfrozen afterwards if
        ! ion_dynamics_freeze reported that it froze them).
        sys%hm%ks_pot%vhxc(:, :) = M_HALF * (sys%hm%ks_pot%vhxc(:, :) + vhxc(:, :))
        call update_hamiltonian_elec_chi(i-1, sys%namespace, sys%space, sys%gr, sys%ks, sys%hm, sys%ext_partners, &
          td, tg, par, sys%ions, st_ref, qtildehalf)
        freeze = ion_dynamics_freeze(td%ions_dyn)
        call propagator_elec_dt(sys%ks, sys%namespace, sys%space, sys%hm, sys%gr, chi, tr_chi, abs((i-1)*td%dt), td%dt, &
          i-1, td%ions_dyn, sys%ions, sys%ext_partners, sys%mc, sys%outp, td%write_handler)
        if (freeze) call ion_dynamics_unfreeze(td%ions_dyn)

        ! Back to the ionic state after psi's step: regenerate potential and
        ! forces there, compute the new co-state forces and finish the Verlet step.
        if (td%ions_dyn%ions_move()) then
          call ion_dynamics_restore_state(td%ions_dyn, sys%ions, ions_state_final)
          call hamiltonian_elec_epot_generate(sys%hm, sys%namespace, sys%space, sys%gr, sys%ions, &
            sys%ext_partners, psi, time = abs((i-1)*td%dt))
          call forces_calculate(sys%gr, sys%namespace, sys%ions, sys%hm, sys%ext_partners, &
            psi, sys%ks, t = abs((i-1)*td%dt), dt = td%dt)
          call forces_costate_calculate(sys%gr, sys%namespace, sys%ions, sys%hm, psi, chi, fnew, q)
          call ion_dynamics_verlet_step2(sys%ions, p, fold, fnew, td%dt)
          SAFE_DEALLOCATE_A(fold)
          SAFE_DEALLOCATE_A(fnew)
        end if

        ! Restore the Hxc potential saved before psi's propagation in this step.
        sys%hm%ks_pot%vhxc(:, :) = vhxc(:, :)

      end select

      call oct_prop_dump_states(prop_chi, sys%space, i-1, chi, sys%gr, sys%kpoints, ierr)
      if (ierr /= 0) then
        message(1) = "Unable to write OCT states restart."
        call messages_warning(1, namespace=sys%namespace)
      end if

      if ((mod(i, 100) == 0).and. mpi_grp_is_root(mpi_world)) call loct_progress_bar(td%max_iter-i, td%max_iter)
    end do
    if (mpi_grp_is_root(mpi_world)) then
      call loct_progress_bar(td%max_iter, td%max_iter)
      write(stdout, '(1x)')
    end if

    final_time = loct_clock()
    write(message(1),'(a,f12.2,a)') 'Propagation time: ', final_time - init_time, ' seconds.'
    call messages_info(1, namespace=sys%namespace)

    call states_elec_end(st_ref)

    ! Restore the original sign of the time step.
    td%dt = -td%dt
    call update_hamiltonian_elec_psi(0, sys%namespace, sys%space, sys%gr, sys%ks, sys%hm, sys%ext_partners, &
      td, tg, par, psi, sys%ions)
    call update_field(0, par_chi, sys%space, sys%gr, sys%hm, sys%ext_partners, sys%ions, qcpsi, qcchi, par, dir = 'b')

    call density_calc(psi, sys%gr, psi%rho)
    call v_ks_calc(sys%ks, sys%namespace, sys%space, sys%hm, psi, sys%ions, sys%ext_partners)
    call sys%hm%update(sys%gr, sys%namespace, sys%space, sys%ext_partners)

    call propagator_elec_end(tr_chi)

    SAFE_DEALLOCATE_A(vhxc)
    ! NOTE(review): only the electronic part of qcpsi is released here;
    ! confirm whether its classical components also need to be freed.
    call states_elec_end(psi)

    nullify(chi)
    nullify(psi)
    nullify(q)
    nullify(p)
    SAFE_DEALLOCATE_A(qtildehalf)
    SAFE_DEALLOCATE_A(qinitial)
    POP_SUB(bwd_step_2)
  end subroutine bwd_step_2
  ! ----------------------------------------------------------


  ! ----------------------------------------------------------
  !
  ! ----------------------------------------------------------
  !> Prepare the Hamiltonian for one backward (adjoint) propagation step of
  !! the Lagrange multiplier chi.
  !!
  !! Steps, in order:
  !!  1. time-dependent target: set the inhomogeneous term built by target_inh
  !!     at time abs(td%dt)*iter;
  !!  2. moving ions: set an inhomogeneous term built by applying the ionic
  !!     perturbation to every state of st, weighted by
  !!     -occ(ist, ik)*qtildehalf(iatom, idim);
  !!  3. interacting theories (not INDEPENDENT_PARTICLES and Hxc not frozen):
  !!     recompute the density of st and set the OCT exchange term;
  !!  4. switch hm to its adjoint form;
  !!  5. call controlfunction_to_h_val with par_chi for iterations
  !!     iter-2 .. iter+2 (clipped to 0 .. td%max_iter, passed as index j+1).
  !!
  !! qtildehalf is optional in the interface but is required whenever the
  !! ions move, because step 2 reads it.
  !!
  !! NOTE(review): if both steps 1 and 2 apply, hamiltonian_elec_set_inh is
  !! called twice; whether the second call replaces or augments the first
  !! depends on hamiltonian_elec_set_inh -- confirm.
  subroutine update_hamiltonian_elec_chi(iter, namespace, space, gr, ks, hm, ext_partners, td, tg, par_chi, ions, st, qtildehalf)
    integer,                  intent(in)    :: iter
    type(namespace_t),        intent(in)    :: namespace
    class(space_t),           intent(in)    :: space
    type(grid_t),             intent(inout) :: gr
    type(v_ks_t),             intent(inout) :: ks
    type(hamiltonian_elec_t), intent(inout) :: hm
    type(partner_list_t),     intent(in)    :: ext_partners
    type(td_t),               intent(inout) :: td
    type(target_t),           intent(inout) :: tg
    type(controlfunction_t),  intent(in)    :: par_chi
    type(ions_t),             intent(in)    :: ions
    type(states_elec_t),      intent(inout) :: st
    real(real64),   optional, intent(in)    :: qtildehalf(:, :)

    type(states_elec_t) :: inh
    type(perturbation_ionic_t), pointer :: pert
    integer :: j, iatom, idim
    complex(real64), allocatable :: dvpsi(:, :, :), zpsi(:, :), inhzpsi(:, :)
    integer :: ist, ik, ib

    PUSH_SUB(update_hamiltonian_elec_chi)

    ! The state loops below run over all states and k-points locally.
    ASSERT(.not. st%parallel_in_states)
    ASSERT(.not. st%d%kpt%parallel)

    ! Step 1: inhomogeneous term coming from a time-dependent target.
    if (target_mode(tg) == oct_targetmode_td) then
      call states_elec_copy(inh, st)
      call target_inh(st, gr, hm%kpoints, tg, abs(td%dt)*iter, inh, iter)
      call hamiltonian_elec_set_inh(hm, inh)
      call states_elec_end(inh)
    end if

    ! Step 2: inhomogeneous term coupling chi to the ionic co-state.
    if (td%ions_dyn%ions_move()) then
      ! qtildehalf is optional, but it is dereferenced unconditionally below.
      ASSERT(present(qtildehalf))

      call states_elec_copy(inh, st)
      SAFE_ALLOCATE(dvpsi(1:gr%np_part, 1:st%d%dim, 1:space%dim))
      ! inh only borrows the layout of st; its contents start from zero.
      do ik = inh%d%kpt%start, inh%d%kpt%end
        do ib = inh%group%block_start, inh%group%block_end
          call batch_set_zero(inh%group%psib(ib, ik))
        end do
      end do
      SAFE_ALLOCATE(zpsi(1:gr%np_part, 1:st%d%dim))
      SAFE_ALLOCATE(inhzpsi(1:gr%np_part, 1:st%d%dim))
      do ist = 1, st%nst
        do ik = 1, st%nik

          call states_elec_get_state(st, gr, ist, ik, zpsi)
          call states_elec_get_state(inh, gr, ist, ik, inhzpsi)

          pert => perturbation_ionic_t(namespace, ions)
          do iatom = 1, ions%natoms
            call pert%setup_atom(iatom)
            do idim = 1, space%dim
              call pert%setup_dir(idim)
              ! dvpsi(:, :, idim) <- ionic perturbation for (iatom, idim) applied to zpsi
              call pert%zapply(namespace, space, gr, hm, ik, zpsi(:, :), dvpsi(:, :, idim))

              ! inhzpsi <- inhzpsi - occ(ist, ik) * qtildehalf(iatom, idim) * dvpsi
              call lalg_axpy(gr%np, st%d%dim, -st%occ(ist, ik)*qtildehalf(iatom, idim), &
                dvpsi(:, :, idim), inhzpsi(:,  :))
            end do
          end do
          SAFE_DEALLOCATE_P(pert)
          call states_elec_set_state(inh, gr, ist, ik, inhzpsi)
        end do
      end do

      SAFE_DEALLOCATE_A(zpsi)
      SAFE_DEALLOCATE_A(inhzpsi)
      SAFE_DEALLOCATE_A(dvpsi)
      call hamiltonian_elec_set_inh(hm, inh)
      call states_elec_end(inh)
    end if

    ! Step 3: OCT exchange term for interacting theories.
    if (hm%theory_level /= INDEPENDENT_PARTICLES .and. (.not. ks%frozen_hxc)) then
      call density_calc(st, gr, st%rho)
      call oct_exchange_set(hm%oct_exchange, st, gr)
    end if

    ! Step 4: backward propagation uses the adjoint Hamiltonian.
    call hamiltonian_elec_adjoint(hm)

    ! Step 5: control-function values around the current iteration.
    do j = iter - 2, iter + 2
      if (j >= 0 .and. j <= td%max_iter) then
        call controlfunction_to_h_val(par_chi, ext_partners, j+1)
      end if
    end do

    POP_SUB(update_hamiltonian_elec_chi)
  end subroutine update_hamiltonian_elec_chi
  ! ---------------------------------------------------------


  ! ----------------------------------------------------------
  !
  ! ----------------------------------------------------------
  !> Undo the adjoint set-up used for chi and restore the Hamiltonian used to
  !! propagate psi.
  !!
  !! In order, the routine:
  !!  - removes the inhomogeneous term (once for a time-dependent target, and
  !!    once more when the ions move);
  !!  - removes the OCT exchange term for interacting theories;
  !!  - switches hm back to its non-adjoint form;
  !!  - calls controlfunction_to_h_val with par for iterations iter-2 .. iter+2,
  !!    clipped to 0 .. td%max_iter;
  !!  - for interacting theories, rebuilds the density, the KS potential and
  !!    the Hamiltonian from st.
  subroutine update_hamiltonian_elec_psi(iter, namespace, space, gr, ks, hm, ext_partners, td, tg, par, st, ions)
    integer,                  intent(in)    :: iter
    type(namespace_t),        intent(in)    :: namespace
    type(electron_space_t),   intent(in)    :: space
    type(grid_t),             intent(inout) :: gr
    type(v_ks_t),             intent(inout) :: ks
    type(hamiltonian_elec_t), intent(inout) :: hm
    type(partner_list_t),     intent(in)    :: ext_partners
    type(td_t),               intent(inout) :: td
    type(target_t),           intent(inout) :: tg
    type(controlfunction_t),  intent(in)    :: par
    type(states_elec_t),      intent(inout) :: st
    type(ions_t),             intent(in)    :: ions

    integer :: jstep

    PUSH_SUB(update_hamiltonian_elec_psi)

    ! Drop every inhomogeneous term the chi set-up may have installed.
    if (target_mode(tg) == oct_targetmode_td) call hamiltonian_elec_remove_inh(hm)
    if (td%ions_dyn%ions_move()) call hamiltonian_elec_remove_inh(hm)

    if (.not. (hm%theory_level == INDEPENDENT_PARTICLES .or. ks%frozen_hxc)) then
      call oct_exchange_remove(hm%oct_exchange)
    end if

    call hamiltonian_elec_not_adjoint(hm)

    ! Only iterations inside [0, td%max_iter] carry control-function values.
    do jstep = max(0, iter - 2), min(td%max_iter, iter + 2)
      call controlfunction_to_h_val(par, ext_partners, jstep + 1)
    end do

    ! Interacting theories: rebuild density, potential and Hamiltonian from st.
    if (.not. (hm%theory_level == INDEPENDENT_PARTICLES .or. ks%frozen_hxc)) then
      call density_calc(st, gr, st%rho)
      call v_ks_calc(ks, namespace, space, hm, st, ions, ext_partners)
      call hm%update(gr, namespace, space, ext_partners)
    end if

    POP_SUB(update_hamiltonian_elec_psi)
  end subroutine update_hamiltonian_elec_psi
  ! ---------------------------------------------------------


  ! ---------------------------------------------------------
  !> For every laser in lasers, accumulate matrix elements between chi and
  !! the laser operator applied to psi, summed over k-points and states.
  !!
  !! dl(j) gets the part linear in the field of laser j (vlaser_operator_linear).
  !! dq(j) gets the part quadratic in the vector potential
  !! (vlaser_operator_quadratic). dq(j) is only evaluated for E_FIELD_MAGNETIC
  !! lasers and is zero otherwise. Each term is zmf_dotp(gr, dim, chi_state, op_psi_state).
  subroutine calculate_g(space, gr, hm, lasers, psi, chi, dl, dq)
    class(space_t),                 intent(in)    :: space
    type(grid_t),                   intent(inout) :: gr
    type(hamiltonian_elec_t),       intent(in)    :: hm
    type(lasers_t),                 intent(in)    :: lasers
    type(states_elec_t),            intent(inout) :: psi
    type(states_elec_t),            intent(inout) :: chi
    complex(real64),                intent(inout) :: dl(:), dq(:)

    complex(real64), allocatable :: zpsi(:, :), zoppsi(:, :)
    integer :: nlaser, ilaser, ik, ist

    PUSH_SUB(calculate_g)

    nlaser = lasers%no_lasers

    SAFE_ALLOCATE(zpsi(1:gr%np_part, 1:chi%d%dim))
    SAFE_ALLOCATE(zoppsi(1:gr%np_part, 1:chi%d%dim))

    do ilaser = 1, nlaser
      dl(ilaser) = M_z0
      dq(ilaser) = M_z0

      ! Linear coupling.
      do ik = 1, psi%nik
        do ist = 1, psi%nst
          call states_elec_get_state(psi, gr, ist, ik, zpsi)
          zoppsi = M_z0
          if (.not. allocated(hm%ep%a_static)) then
            call vlaser_operator_linear(lasers%lasers(ilaser), gr%der, hm%d, zpsi, &
              zoppsi, ik, hm%ep%gyromagnetic_ratio)
          else
            call vlaser_operator_linear(lasers%lasers(ilaser), gr%der, hm%d, zpsi, &
              zoppsi, ik, hm%ep%gyromagnetic_ratio, hm%ep%a_static)
          end if
          ! Reuse zpsi for the chi state.
          call states_elec_get_state(chi, gr, ist, ik, zpsi)
          dl(ilaser) = dl(ilaser) + zmf_dotp(gr, psi%d%dim, zpsi, zoppsi)
        end do
      end do

      ! The quadratic coupling is only needed for magnetic-field lasers.
      if (laser_kind(lasers%lasers(ilaser)) /= E_FIELD_MAGNETIC) cycle

      do ik = 1, psi%nik
        do ist = 1, psi%nst
          zoppsi = M_z0
          call states_elec_get_state(psi, gr, ist, ik, zpsi)
          call vlaser_operator_quadratic(lasers%lasers(ilaser), gr, space, zpsi, zoppsi)
          call states_elec_get_state(chi, gr, ist, ik, zpsi)
          dq(ilaser) = dq(ilaser) + zmf_dotp(gr, psi%d%dim, zpsi, zoppsi)
        end do
      end do
    end do

    SAFE_DEALLOCATE_A(zpsi)
    SAFE_DEALLOCATE_A(zoppsi)

    POP_SUB(calculate_g)
  end subroutine calculate_g
  ! ---------------------------------------------------------




  !> Calculates the value of the control functions at iteration
  !! iter, from the state psi and the Lagrange-multiplier chi.
  !!
  !! If dir = 'f', the field must be updated for a forward
  !! propagation. In that case, the propagation step that is
  !! going to be done moves from (iter-1)*|dt| to iter*|dt|.
  !!
  !! If dir = 'b', the field must be updated for a backward
  !! propagation. In that case, the propagation step that is
  !! going to be done moves from iter*|dt| to (iter-1)*|dt|.
  !!
  !! cp = (1-eta)*cpp - (eta/alpha) * <chi|V|Psi>
  !!
  !! The forward update passes delta_ to controlfunction_update; the
  !! backward update passes eta_. If ext_partners holds no lasers, the
  !! couplings <chi|V|psi> (linear dl and quadratic dq) are taken as zero.
  subroutine update_field(iter, cp, space, gr, hm, ext_partners, ions, qcpsi, qcchi, cpp, dir)
    class(space_t),            intent(in)    :: space
    integer,                   intent(in)    :: iter
    type(controlfunction_t),   intent(inout) :: cp
    type(grid_t),              intent(inout) :: gr
    type(hamiltonian_elec_t),  intent(in)    :: hm
    type(partner_list_t),      intent(in)    :: ext_partners
    type(ions_t),              intent(in)    :: ions
    type(opt_control_state_t), intent(inout) :: qcpsi
    type(opt_control_state_t), intent(inout) :: qcchi
    type(controlfunction_t),   intent(in)    :: cpp
    character(len=1),          intent(in)    :: dir

    complex(real64) :: d1, pol(3)
    complex(real64), allocatable  :: dl(:), dq(:), zpsi(:, :), zchi(:, :)
    real(real64), allocatable :: d(:)
    integer :: j, no_parameters, iatom
    type(states_elec_t), pointer :: psi, chi
    real(real64), pointer :: q(:, :)
    type(lasers_t), pointer :: lasers

    PUSH_SUB(update_field)

    psi => opt_control_point_qs(qcpsi)
    chi => opt_control_point_qs(qcchi)
    q => opt_control_point_q(qcchi)

    no_parameters = controlfunction_number(cp)

    SAFE_ALLOCATE(dl(1:no_parameters))
    SAFE_ALLOCATE(dq(1:no_parameters))
    SAFE_ALLOCATE( d(1:no_parameters))

    ! calculate_g only fills dl and dq when lasers are present. Start from zero
    ! so the update below never reads uninitialised memory.
    dl = M_z0
    dq = M_z0

    lasers => list_get_lasers(ext_partners)
    if (associated(lasers)) then
      call calculate_g(space, gr, hm, lasers, psi, chi, dl, dq)
    end if

    d1 = M_z1
    if (zbr98_) then
      SAFE_ALLOCATE(zpsi(1:gr%np, 1:psi%d%dim))
      SAFE_ALLOCATE(zchi(1:gr%np, 1:chi%d%dim))

      call states_elec_get_state(psi, gr, 1, 1, zpsi)
      call states_elec_get_state(chi, gr, 1, 1, zchi)

      d1 = zmf_dotp(gr, psi%d%dim, zpsi, zchi)
      do j = 1, no_parameters
        d(j) = aimag(d1*dl(j)) / controlfunction_alpha(cp, j)
      end do

      SAFE_DEALLOCATE_A(zpsi)
      SAFE_DEALLOCATE_A(zchi)

    elseif (gradients_) then
      do j = 1, no_parameters
        d(j) = M_TWO * aimag(dl(j))
      end do
    else
      do j = 1, no_parameters
        d(j) = aimag(dl(j)) / controlfunction_alpha(cp, j)
      end do
    end if

    ! This is for the classical target. It writes d(1), so it needs at least
    ! one control parameter.
    if (dir == 'b' .and. associated(lasers) .and. no_parameters > 0) then
      pol = laser_polarization(lasers%lasers(1))
      do iatom = 1, ions%natoms
        d(1) = d(1) - ions%charge(iatom) * real(sum(pol(1:ions%space%dim)*q(iatom, 1:ions%space%dim)), real64)
      end do
    end if


    if (dir == 'f') then
      call controlfunction_update(cp, cpp, dir, iter, delta_, d, dq)
    else
      call controlfunction_update(cp, cpp, dir, iter, eta_, d, dq)
    end if

    nullify(q)
    nullify(psi)
    nullify(chi)
    SAFE_DEALLOCATE_A(d)
    SAFE_DEALLOCATE_A(dl)
    SAFE_DEALLOCATE_A(dq)
    POP_SUB(update_field)
  end subroutine update_field
  ! ---------------------------------------------------------


  ! ---------------------------------------------------------
  !> Initialise an OCT propagation bookkeeping object.
  !!
  !! The routine:
  !!  - stores dirname and the module settings niter_ and number_checkpoints_;
  !!  - opens the dump and load restart objects for RESTART_OCT;
  !!  - builds prop%iter, the list of checkpoint iterations. It contains 0,
  !!    then number_checkpoints evenly spaced interior iterations (rounded
  !!    with nint), then niter_.
  subroutine oct_prop_init(prop, namespace, dirname, mesh, mc)
    type(oct_prop_t),  intent(inout) :: prop
    type(namespace_t), intent(in)    :: namespace
    character(len=*),  intent(in)    :: dirname
    class(mesh_t),     intent(in)    :: mesh
    type(multicomm_t), intent(in)    :: mc

    integer :: j, ierr, ncp

    PUSH_SUB(oct_prop_init)

    prop%dirname = dirname
    prop%niter = niter_
    prop%number_checkpoints = number_checkpoints_
    ncp = prop%number_checkpoints

    ! Writing and reading during the calculation go through the same
    ! RESTART_OCT location, so both restart objects are created for it.
    call restart_init(prop%restart_dump, namespace, RESTART_OCT, RESTART_TYPE_DUMP, mc, ierr, mesh=mesh)
    call restart_init(prop%restart_load, namespace, RESTART_OCT, RESTART_TYPE_LOAD, mc, ierr, mesh=mesh)

    ! Checkpoints: the first iteration, ncp interior ones, and the last one.
    SAFE_ALLOCATE(prop%iter(1:ncp+2))
    prop%iter(1) = 0
    prop%iter(2:ncp+1) = [(nint(real(niter_, real64) /(ncp+1) * j), j = 1, ncp)]
    prop%iter(ncp+2) = niter_

    POP_SUB(oct_prop_init)
  end subroutine oct_prop_init
  ! ---------------------------------------------------------


  ! ---------------------------------------------------------
  !> Release what an OCT propagation object holds: its two restart objects
  !! and its list of checkpoint iterations.
  !!
  !! TODO: the checkpoint files written to disk are left in place; consider
  !! whether they should be removed here.
  subroutine oct_prop_end(prop)
    type(oct_prop_t), intent(inout) :: prop

    PUSH_SUB(oct_prop_end)

    ! Restart handles: load first, then dump.
    call restart_end(prop%restart_load)
    call restart_end(prop%restart_dump)

    ! Checkpoint iteration list built by oct_prop_init.
    SAFE_DEALLOCATE_A(prop%iter)

    POP_SUB(oct_prop_end)
  end subroutine oct_prop_end
  ! ---------------------------------------------------------


  ! ---------------------------------------------------------
  !> Compare psi with the states stored at the checkpoint(s) for iteration iter.
  !!
  !! For every checkpoint j with prop%iter(j) == iter, the routine:
  !!  - loads the states saved under trim(prop%dirname) followed by j in i4.4
  !!    format; a read failure is fatal;
  !!  - prints a warning when |<stored|psi> - <stored|stored>| exceeds
  !!    WARNING_THRESHOLD;
  !!  - if there is at least one checkpoint, replaces psi with the stored
  !!    states.
  subroutine oct_prop_check(prop, namespace, space, psi, mesh, kpoints, iter)
    type(oct_prop_t),    intent(inout) :: prop
    type(namespace_t),   intent(in)    :: namespace
    class(space_t),      intent(in)    :: space
    type(states_elec_t), intent(inout) :: psi
    class(mesh_t),       intent(in)    :: mesh
    type(kpoints_t),     intent(in)    :: kpoints
    integer,             intent(in)    :: iter

    type(states_elec_t) :: ref_st
    character(len=80) :: ckpt_dir
    integer :: icp, ierr
    complex(real64) :: overlap, prev_overlap
    real(real64) :: deviation
    real(real64), parameter :: WARNING_THRESHOLD = 1.0e-2_real64

    PUSH_SUB(oct_prop_check)

    do icp = 1, prop%number_checkpoints + 2
      if (prop%iter(icp) /= iter) cycle

      ! ref_st is created as a copy of psi and then overwritten by the stored states.
      call states_elec_copy(ref_st, psi)
      write(ckpt_dir,'(a, i4.4)') trim(prop%dirname), icp
      call restart_open_dir(prop%restart_load, ckpt_dir, ierr)
      if (ierr == 0) then
        call states_elec_load(prop%restart_load, namespace, space, ref_st, mesh, kpoints, ierr, verbose=.false.)
      end if
      if (ierr /= 0) then
        message(1) = "Unable to read wavefunctions from '"//trim(ckpt_dir)//"'."
        call messages_fatal(1, namespace=namespace)
      end if
      call restart_close_dir(prop%restart_load)

      prev_overlap = zstates_elec_mpdotp(namespace, mesh, ref_st, ref_st)
      overlap = zstates_elec_mpdotp(namespace, mesh, ref_st, psi)
      deviation = abs(overlap - prev_overlap)
      if (deviation > WARNING_THRESHOLD) then
        write(message(1), '(a,es13.4)') &
          "Forward-backward propagation produced an error of", deviation
        write(message(2), '(a,i8)') "Iter = ", iter
        call messages_warning(2, namespace=namespace)
      end if

      ! Without intermediate checkpoints psi is kept as propagated.
      if (prop%number_checkpoints > 0) then
        call states_elec_end(psi)
        call states_elec_copy(psi, ref_st)
      end if
      call states_elec_end(ref_st)
    end do

    POP_SUB(oct_prop_check)
  end subroutine oct_prop_check
  ! ---------------------------------------------------------


  ! ---------------------------------------------------------
  !> Write psi to every checkpoint directory whose iteration equals iter.
  !!
  !! ierr is 0 on success. For each checkpoint j that could not be written,
  !! 2**j is added to ierr and a warning is printed. When restart_skip
  !! reports that dumping is disabled, nothing is written and ierr stays 0.
  subroutine oct_prop_dump_states(prop, space, iter, psi, mesh, kpoints, ierr)
    type(oct_prop_t),    intent(inout) :: prop
    class(space_t),      intent(in)    :: space
    integer,             intent(in)    :: iter
    type(states_elec_t), intent(in)    :: psi
    class(mesh_t),       intent(in)    :: mesh
    type(kpoints_t),     intent(in)    :: kpoints
    integer,             intent(out)   :: ierr

    integer :: icp, ckpt_err
    character(len=80) :: ckpt_dir

    PUSH_SUB(oct_prop_dump_states)

    ierr = 0

    if (.not. restart_skip(prop%restart_dump)) then
      message(1) = "Debug: Writing OCT propagation states restart."
      call messages_info(1, debug_only=.true.)

      do icp = 1, prop%number_checkpoints + 2
        if (prop%iter(icp) /= iter) cycle

        write(ckpt_dir,'(a,i4.4)') trim(prop%dirname), icp
        call restart_open_dir(prop%restart_dump, ckpt_dir, ckpt_err)
        if (ckpt_err == 0) then
          call states_elec_dump(prop%restart_dump, space, psi, mesh, kpoints, ckpt_err, iter, verbose = .false.)
        end if
        if (ckpt_err /= 0) then
          message(1) = "Unable to write wavefunctions to '"//trim(ckpt_dir)//"'."
          call messages_warning(1)
          ! One flag bit per failed checkpoint.
          ierr = ierr + 2**icp
        end if
        call restart_close_dir(prop%restart_dump)
      end do

      message(1) = "Debug: Writing OCT propagation states restart done."
      call messages_info(1, debug_only=.true.)
    end if

    POP_SUB(oct_prop_dump_states)
  end subroutine oct_prop_dump_states
  ! ---------------------------------------------------------


  ! ---------------------------------------------------------
  !> Read psi from every checkpoint directory whose iteration equals iter.
  !!
  !! ierr is 0 on success and -1 when restart_skip reports that loading is
  !! disabled (nothing is read in that case). For each checkpoint j that
  !! could not be read, 2**j is added to ierr and a warning is printed.
  subroutine oct_prop_load_states(prop, namespace, space, psi, mesh, kpoints, iter, ierr)
    type(oct_prop_t),    intent(inout) :: prop
    type(namespace_t),   intent(in)    :: namespace
    class(space_t),      intent(in)    :: space
    type(states_elec_t), intent(inout) :: psi
    class(mesh_t),       intent(in)    :: mesh
    type(kpoints_t),     intent(in)    :: kpoints
    integer,             intent(in)    :: iter
    integer,             intent(out)   :: ierr

    integer :: icp, ckpt_err
    character(len=80) :: ckpt_dir

    PUSH_SUB(oct_prop_load_states)

    ierr = 0

    if (restart_skip(prop%restart_load)) then
      ierr = -1
    else
      message(1) = "Debug: Reading OCT propagation states restart."
      call messages_info(1, namespace=namespace, debug_only=.true.)

      do icp = 1, prop%number_checkpoints + 2
        if (prop%iter(icp) /= iter) cycle

        write(ckpt_dir,'(a, i4.4)') trim(prop%dirname), icp
        call restart_open_dir(prop%restart_load, ckpt_dir, ckpt_err)
        if (ckpt_err == 0) then
          call states_elec_load(prop%restart_load, namespace, space, psi, mesh, kpoints, ckpt_err, verbose=.false.)
        end if
        if (ckpt_err /= 0) then
          message(1) = "Unable to read wavefunctions from '"//trim(ckpt_dir)//"'."
          call messages_warning(1, namespace=namespace)
          ! One flag bit per failed checkpoint.
          ierr = ierr + 2**icp
        end if
        call restart_close_dir(prop%restart_load)
      end do

      message(1) = "Debug: Reading OCT propagation states restart done."
      call messages_info(1, namespace=namespace, debug_only=.true.)
    end if

    POP_SUB(oct_prop_load_states)
  end subroutine oct_prop_load_states
  ! ---------------------------------------------------------

  ! ---------------------------------------------------------
  subroutine vlaser_operator_quadratic(laser, mesh, space, psi, hpsi)
    !! Add to hpsi the laser term that is quadratic in the vector potential,
    !! M_HALF * |A|^2 * psi / P_c**2, on the first mesh%np points.
    !!
    !! - E_FIELD_MAGNETIC: A(r) is the spatially varying vector potential
    !!   returned by laser_vector_potential.
    !! - E_FIELD_VECTOR_POTENTIAL: A is the uniform vector returned by laser_field.
    !! - Any other kind (e.g. E_FIELD_ELECTRIC): no contribution.
    type(laser_t),       intent(in)    :: laser
    class(mesh_t),       intent(in)    :: mesh
    class(space_t),      intent(in)    :: space
    complex(real64),     intent(inout) :: psi(:,:)  !< psi(der%mesh%np_part, h%d%dim)
    complex(real64),     intent(inout) :: hpsi(:,:) !< hpsi(der%mesh%np_part, h%d%dim)

    integer :: ip
    real(real64) :: a_field(3)
    real(real64), allocatable :: aa(:, :)

    PUSH_SUB(vlaser_operator_quadratic)

    ! Only the vector potential enters the quadratic term. The magnetic field
    ! itself (previously summed into an uninitialised, unused local) is not
    ! needed here.
    select case (laser_kind(laser))
    case (E_FIELD_MAGNETIC)
      SAFE_ALLOCATE(aa(1:mesh%np_part, 1:space%dim))
      aa = M_ZERO
      call laser_vector_potential(laser, mesh, aa)
      do ip = 1, mesh%np
        hpsi(ip, :) = hpsi(ip, :) + M_HALF * &
          dot_product(aa(ip, 1:space%dim), aa(ip, 1:space%dim)) * psi(ip, :) / P_c**2
      end do
      SAFE_DEALLOCATE_A(aa)

    case (E_FIELD_VECTOR_POTENTIAL)
      a_field = M_ZERO
      call laser_field(laser, a_field(1:space%dim))
      do ip = 1, mesh%np
        hpsi(ip, :) = hpsi(ip, :) + M_HALF * &
          dot_product(a_field(1:space%dim), a_field(1:space%dim))*psi(ip, :) / P_c**2
      end do
    end select

    POP_SUB(vlaser_operator_quadratic)
  end subroutine vlaser_operator_quadratic

  ! ---------------------------------------------------------
  subroutine vlaser_operator_linear(laser, der, std, psi, hpsi, ik, gyromagnetic_ratio, a_static)
    !! Add to hpsi the laser terms that are linear in the field, on the first
    !! der%mesh%np points. What is added depends on laser_kind(laser):
    !!
    !! - E_FIELD_SCALAR_POTENTIAL, E_FIELD_ELECTRIC: v(r) psi, with v from
    !!   laser_potential.
    !! - E_FIELD_MAGNETIC:
    !!     -i/P_c * A(r).grad psi;
    !!     -A(r).a_static psi / P_c, when a_static is present;
    !!     for spin-polarized and spinor states, a Zeeman term
    !!     (gyromagnetic_ratio/2) * M_HALF/P_c * B.sigma psi, with B from
    !!     laser_field (only |B| is used for SPIN_POLARIZED).
    !! - E_FIELD_VECTOR_POTENTIAL: -i/P_c * A.grad psi, with a uniform A.
    !!
    !! ik is only used for SPIN_POLARIZED states, to choose the sign of the
    !! Zeeman term.
    type(laser_t),               intent(in)      :: laser
    type(derivatives_t),         intent(in)      :: der
    type(states_elec_dim_t),     intent(in)      :: std
    complex(real64), contiguous, intent(inout)   :: psi(:,:)
    complex(real64),             intent(inout)   :: hpsi(:,:)
    integer,                     intent(in)      :: ik
    real(real64),                intent(in)      :: gyromagnetic_ratio
    real(real64), optional,      intent(in)      :: a_static(:,:)

    integer :: ip, idim
    logical :: electric_field, vector_potential, magnetic_field
    complex(real64), allocatable :: grad(:, :, :), lhpsi(:, :)

    real(real64) :: a_field(3), a_field_prime(3), bb(3), b_prime(3)

    real(real64), allocatable :: vv(:), pot(:), aa(:, :), a_prime(:, :)

    PUSH_SUB(vlaser_operator_linear)

    ! Both field accumulators must start from zero. bb is summed into below
    ! and later feeds the Zeeman term, so leaving it uninitialised would
    ! inject garbage into hpsi.
    a_field = M_ZERO
    bb = M_ZERO

    electric_field = .false.
    vector_potential = .false.
    magnetic_field = .false.

    select case (laser_kind(laser))
    case (E_FIELD_SCALAR_POTENTIAL)
      if (.not. allocated(vv)) then
        SAFE_ALLOCATE(vv(1:der%mesh%np))
      end if
      vv = M_ZERO
      call laser_potential(laser, der%mesh, vv)
      electric_field = .true.

    case (E_FIELD_ELECTRIC)
      if (.not. allocated(vv)) then
        SAFE_ALLOCATE(vv(1:der%mesh%np))
        vv = M_ZERO
        SAFE_ALLOCATE(pot(1:der%mesh%np))
      end if
      pot = M_ZERO
      call laser_potential(laser, der%mesh, pot)
      call lalg_axpy(der%mesh%np, M_ONE, pot, vv)
      electric_field = .true.
      SAFE_DEALLOCATE_A(pot)

    case (E_FIELD_MAGNETIC)
      if (.not. allocated(aa)) then
        SAFE_ALLOCATE(aa(1:der%mesh%np_part, 1:der%dim))
        aa = M_ZERO
        SAFE_ALLOCATE(a_prime(1:der%mesh%np_part, 1:der%dim))
      end if
      a_prime = M_ZERO
      call laser_vector_potential(laser, der%mesh, a_prime)
      aa = aa + a_prime
      b_prime = M_ZERO
      call laser_field(laser, b_prime(1:der%dim))
      bb = bb + b_prime
      magnetic_field = .true.
    case (E_FIELD_VECTOR_POTENTIAL)
      a_field_prime = M_ZERO
      call laser_field(laser, a_field_prime(1:der%dim))
      a_field = a_field + a_field_prime
      vector_potential = .true.
    end select

    if (electric_field) then
      do idim = 1, std%dim
        hpsi(1:der%mesh%np, idim)= hpsi(1:der%mesh%np, idim) + vv(1:der%mesh%np) * psi(1:der%mesh%np, idim)
      end do
      SAFE_DEALLOCATE_A(vv)
    end if

    if (magnetic_field) then
      SAFE_ALLOCATE(grad(1:der%mesh%np_part, 1:der%dim, 1:std%dim))

      do idim = 1, std%dim
        call zderivatives_grad(der, psi(:, idim), grad(:, :, idim))
      end do

      ! If there is a static magnetic field, its associated vector potential is coupled with
      ! the time-dependent one defined as a "laser" (ideally one should just add them all and
      ! do the calculation only once...). Note that h%ep%a_static already has been divided
      ! by P_c, and therefore here we only divide by P_c, and not P_c**2.
      !
      ! We put a minus sign, since for the moment vector potential for
      ! lasers and for the static magnetic field use a different
      ! convention.
      if (present(a_static)) then
        do ip = 1, der%mesh%np
          hpsi(ip, :) = hpsi(ip, :) - dot_product(aa(ip, 1:der%dim), a_static(ip, 1:der%dim)) * psi(ip, :) / P_c
        end do
      end if

      ! Paramagnetic coupling -i/P_c * A.grad psi.
      select case (std%ispin)
      case (UNPOLARIZED, SPIN_POLARIZED)
        do ip = 1, der%mesh%np
          hpsi(ip, 1) = hpsi(ip, 1) - M_zI * dot_product(aa(ip, 1:der%dim), grad(ip, 1:der%dim, 1)) / P_c
        end do
      case (SPINORS)
        do ip = 1, der%mesh%np
          do idim = 1, std%dim
            hpsi(ip, idim) = hpsi(ip, idim) - M_zI * &
              dot_product(aa(ip, 1:der%dim), grad(ip, 1:der%dim, idim)) / P_c
          end do
        end do
      end select


      ! Zeeman-like coupling of the spin to the laser magnetic field bb.
      select case (std%ispin)
      case (SPIN_POLARIZED)
        SAFE_ALLOCATE(lhpsi(1:der%mesh%np, 1:std%dim))
        ! NOTE(review): this branch is taken for odd ik and is labelled spin
        ! down; confirm against the spin/k-point indexing convention.
        if (modulo(ik+1, 2) == 0) then ! we have a spin down
          lhpsi(1:der%mesh%np, 1) = - M_HALF / P_c * norm2(bb) * psi(1:der%mesh%np, 1)
        else
          lhpsi(1:der%mesh%np, 1) = + M_HALF / P_c * norm2(bb) * psi(1:der%mesh%np, 1)
        end if
        hpsi(1:der%mesh%np, :) = hpsi(1:der%mesh%np, :) + (gyromagnetic_ratio * M_HALF) * lhpsi(1:der%mesh%np, :)
        SAFE_DEALLOCATE_A(lhpsi)

      case (SPINORS)
        ! lhpsi = M_HALF/P_c * (B.sigma) psi, written out with Pauli matrices.
        SAFE_ALLOCATE(lhpsi(1:der%mesh%np, 1:std%dim))
        lhpsi(1:der%mesh%np, 1) = M_HALF / P_c * (bb(3) * psi(1:der%mesh%np, 1) &
          + cmplx(bb(1), -bb(2), real64) * psi(1:der%mesh%np, 2))
        lhpsi(1:der%mesh%np, 2) = M_HALF / P_c * (-bb(3) * psi(1:der%mesh%np, 2) &
          + cmplx(bb(1), bb(2), real64) * psi(1:der%mesh%np, 1))
        hpsi(1:der%mesh%np, :) = hpsi(1:der%mesh%np, :) + (gyromagnetic_ratio * M_HALF) * lhpsi(1:der%mesh%np, :)
        SAFE_DEALLOCATE_A(lhpsi)
      end select

      SAFE_DEALLOCATE_A(grad)
      SAFE_DEALLOCATE_A(aa)
      SAFE_DEALLOCATE_A(a_prime)
    end if

    if (vector_potential) then
      SAFE_ALLOCATE(grad(1:der%mesh%np_part, 1:der%dim, 1:std%dim))

      do idim = 1, std%dim
        call zderivatives_grad(der, psi(:, idim), grad(:, :, idim))
      end do

      ! Uniform vector potential: -i/P_c * A.grad psi.
      select case (std%ispin)
      case (UNPOLARIZED, SPIN_POLARIZED)
        do ip = 1, der%mesh%np
          hpsi(ip, 1) = hpsi(ip, 1) - M_zI * dot_product(a_field(1:der%dim), grad(ip, 1:der%dim, 1)) / P_c
        end do
      case (SPINORS)
        do ip = 1, der%mesh%np
          do idim = 1, std%dim
            hpsi(ip, idim) = hpsi(ip, idim) - M_zI * &
              dot_product(a_field(1:der%dim), grad(ip, 1:der%dim, idim)) / P_c
          end do
        end do
      end select
      SAFE_DEALLOCATE_A(grad)
    end if

    POP_SUB(vlaser_operator_linear)
  end subroutine vlaser_operator_linear


end module propagation_oct_m

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
