IP_SSATaucTikhonovGNSolver.cc
// Copyright (C) 2012, 2013, 2014, 2015, 2016, 2017, 2019, 2020, 2021, 2022, 2023 David Maxwell and Constantine Khroulev
//
// This file is part of PISM.
//
// PISM is free software; you can redistribute it and/or modify it under the
// terms of the GNU General Public License as published by the Free Software
// Foundation; either version 3 of the License, or (at your option) any later
// version.
//
// PISM is distributed in the hope that it will be useful, but WITHOUT ANY
// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
// FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
// details.
//
// You should have received a copy of the GNU General Public License
// along with PISM; if not, write to the Free Software
// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA

#include "pism/inverse/IP_SSATaucTikhonovGNSolver.hh"
#include "pism/util/TerminationReason.hh"
#include "pism/util/pism_options.hh"
#include "pism/util/ConfigInterface.hh"
#include "pism/util/Grid.hh"
#include "pism/util/Context.hh"
#include "pism/util/petscwrappers/Vec.hh"
namespace pism {
namespace inverse {

IP_SSATaucTikhonovGNSolver::IP_SSATaucTikhonovGNSolver(IP_SSATaucForwardProblem &ssaforward,
                                                       DesignVec &d0, StateVec &u_obs, double eta,
                                                       IPInnerProductFunctional<DesignVec> &designFunctional,
                                                       IPInnerProductFunctional<StateVec> &stateFunctional)
  : m_design_stencil_width(d0.stencil_width()),
    m_state_stencil_width(u_obs.stencil_width()),
    m_ssaforward(ssaforward),
    m_x(d0.grid(), "x"),
    m_tmp_D1Global(d0.grid(), "work vector"),
    m_tmp_D2Global(d0.grid(), "work vector"),
    m_tmp_D1Local(d0.grid(), "work vector"),
    m_tmp_D2Local(d0.grid(), "work vector"),
    m_tmp_S1Global(d0.grid(), "work vector"),
    m_tmp_S2Global(d0.grid(), "work vector"),
    m_tmp_S1Local(d0.grid(), "work vector"),
    m_tmp_S2Local(d0.grid(), "work vector"),
    m_GN_rhs(d0.grid(), "GN_rhs"),
    m_d0(d0),
    m_dGlobal(d0.grid(), "d (sans ghosts)"),
    m_d_diff(d0.grid(), "d_diff"),
    m_d_diff_lin(d0.grid(), "d_diff linearized"),
    m_h(d0.grid(), "h"),
    m_hGlobal(d0.grid(), "h (sans ghosts)"),
    m_dalpha_rhs(d0.grid(), "dalpha rhs"),
    m_dh_dalpha(d0.grid(), "dh_dalpha"),
    m_dh_dalphaGlobal(d0.grid(), "dh_dalpha (sans ghosts)"),
    m_grad_design(d0.grid(), "grad design"),
    m_grad_state(d0.grid(), "grad state"),
    m_gradient(d0.grid(), "gradient"),
    m_u_obs(u_obs),
    m_u_diff(d0.grid(), "du"),
    m_eta(eta),
    m_designFunctional(designFunctional),
    m_stateFunctional(stateFunctional),
    m_target_misfit(0.0)
{
  PetscErrorCode ierr;
  std::shared_ptr<const Grid> grid = m_d0.grid();
  m_comm = grid->com;

  m_d = std::make_shared<DesignVecGhosted>(grid, "d");

  ierr = KSPCreate(grid->com, m_ksp.rawptr());
  PISM_CHK(ierr, "KSPCreate");

  ierr = KSPSetOptionsPrefix(m_ksp, "inv_gn_");
  PISM_CHK(ierr, "KSPSetOptionsPrefix");

  double ksp_rtol = 1e-5; // Soft tolerance
  ierr = KSPSetTolerances(m_ksp, ksp_rtol, PETSC_DEFAULT, PETSC_DEFAULT, PETSC_DEFAULT);
  PISM_CHK(ierr, "KSPSetTolerances");

  ierr = KSPSetType(m_ksp, KSPCG);
  PISM_CHK(ierr, "KSPSetType");

  PC pc;
  ierr = KSPGetPC(m_ksp, &pc);
  PISM_CHK(ierr, "KSPGetPC");

  ierr = PCSetType(pc, PCNONE);
  PISM_CHK(ierr, "PCSetType");

  ierr = KSPSetFromOptions(m_ksp);
  PISM_CHK(ierr, "KSPSetFromOptions");

  int nLocalNodes  = grid->xm()*grid->ym();
  int nGlobalNodes = grid->Mx()*grid->My();
  ierr = MatCreateShell(grid->com, nLocalNodes, nLocalNodes,
                        nGlobalNodes, nGlobalNodes, this, m_mat_GN.rawptr());
  PISM_CHK(ierr, "MatCreateShell");

  typedef MatrixMultiplyCallback<IP_SSATaucTikhonovGNSolver,
                                 &IP_SSATaucTikhonovGNSolver::apply_GN> multCallback;
  multCallback::connect(m_mat_GN);

  m_alpha = 1./m_eta;
  m_logalpha = log(m_alpha);

  m_tikhonov_adaptive = options::Bool("-tikhonov_adaptive", "Tikhonov adaptive");

  m_iter_max = 1000;
  m_iter_max = options::Integer("-inv_gn_iter_max", "", m_iter_max);

  m_tikhonov_atol = grid->ctx()->config()->get_number("inverse.tikhonov.atol");
  m_tikhonov_rtol = grid->ctx()->config()->get_number("inverse.tikhonov.rtol");
  m_tikhonov_ptol = grid->ctx()->config()->get_number("inverse.tikhonov.ptol");

  m_log = d0.grid()->ctx()->log();
}
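
// With alpha = 1/eta, the objective this solver minimizes is
//   J(d) = alpha * J_design(d - d0) + J_state(F(d) - u_obs),
// where F is the forward map implemented by m_ssaforward; see
// evaluate_objective_and_gradient() below. The shell matrix m_mat_GN carries
// no entries of its own: PETSc forwards each matrix-vector product to
// apply_GN() via the callback connected in the constructor.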

std::shared_ptr<TerminationReason> IP_SSATaucTikhonovGNSolver::init() {
  return m_ssaforward.linearize_at(m_d0);
}

void IP_SSATaucTikhonovGNSolver::apply_GN(array::Scalar &x, array::Scalar &y) {
  this->apply_GN(x.vec(), y.vec());
}

//! @note This function has to return PetscErrorCode (it is used as a callback).
void IP_SSATaucTikhonovGNSolver::apply_GN(Vec x, Vec y) {
  StateVec &tmp_gS  = m_tmp_S1Global;
  StateVec &Tx      = m_tmp_S1Local;
  DesignVec &tmp_gD = m_tmp_D1Global;
  DesignVec &GNx    = m_tmp_D2Global;

  PetscErrorCode ierr;
  // FIXME: Needless copies for now.
  {
    ierr = DMGlobalToLocalBegin(*m_x.dm(), x, INSERT_VALUES, m_x.vec());
    PISM_CHK(ierr, "DMGlobalToLocalBegin");

    ierr = DMGlobalToLocalEnd(*m_x.dm(), x, INSERT_VALUES, m_x.vec());
    PISM_CHK(ierr, "DMGlobalToLocalEnd");
  }

  m_ssaforward.apply_linearization(m_x, Tx);
  Tx.update_ghosts();

  m_stateFunctional.interior_product(Tx, tmp_gS);

  m_ssaforward.apply_linearization_transpose(tmp_gS, GNx);

  m_designFunctional.interior_product(m_x, tmp_gD);
  GNx.add(m_alpha, tmp_gD);

  ierr = VecCopy(GNx.vec(), y);
  PISM_CHK(ierr, "VecCopy");
}
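
// Viewed as a linear operator on the design space, apply_GN() computes
//   y = T^* W_S (T x) + alpha * W_D x,
// where T is apply_linearization(), T^* is apply_linearization_transpose(),
// and W_S, W_D denote the inner products defined by m_stateFunctional and
// m_designFunctional (their interior_product() methods).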

void IP_SSATaucTikhonovGNSolver::assemble_GN_rhs(DesignVec &rhs) {

  rhs.set(0);

  m_stateFunctional.interior_product(m_u_diff, m_tmp_S1Global);
  m_ssaforward.apply_linearization_transpose(m_tmp_S1Global, rhs);

  m_designFunctional.interior_product(m_d_diff, m_tmp_D1Global);
  rhs.add(m_alpha, m_tmp_D1Global);

  rhs.scale(-1);
}
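
// Together with apply_GN(), this yields the Gauss-Newton step equation
//   (T^* W_S T + alpha W_D) h = -(T^* W_S (u - u_obs) + alpha W_D (d - d0)),
// which solve_linearized() solves for the step h using CG.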

std::shared_ptr<TerminationReason> IP_SSATaucTikhonovGNSolver::solve_linearized() {
  PetscErrorCode ierr;

  this->assemble_GN_rhs(m_GN_rhs);

  ierr = KSPSetOperators(m_ksp, m_mat_GN, m_mat_GN);
  PISM_CHK(ierr, "KSPSetOperators");

  ierr = KSPSolve(m_ksp, m_GN_rhs.vec(), m_hGlobal.vec());
  PISM_CHK(ierr, "KSPSolve");

  KSPConvergedReason ksp_reason;
  ierr = KSPGetConvergedReason(m_ksp, &ksp_reason);
  PISM_CHK(ierr, "KSPGetConvergedReason");

  m_h.copy_from(m_hGlobal);

  return std::shared_ptr<TerminationReason>(new KSPTerminationReason(ksp_reason));
}

void IP_SSATaucTikhonovGNSolver::evaluateGNFunctional(DesignVec &h, double *value) {

  m_ssaforward.apply_linearization(h, m_tmp_S1Local);
  m_tmp_S1Local.update_ghosts();
  m_tmp_S1Local.add(1, m_u_diff);

  double sValue;
  m_stateFunctional.valueAt(m_tmp_S1Local, &sValue);

  m_tmp_D1Local.copy_from(m_d_diff);
  m_tmp_D1Local.add(1, h);

  double dValue;
  m_designFunctional.valueAt(m_tmp_D1Local, &dValue);

  *value = m_alpha*dValue + sValue;
}
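
// evaluateGNFunctional() computes the Gauss-Newton (linearized) model of the
// objective at a step h:
//   J_GN(h) = alpha * J_design((d - d0) + h) + J_state(T h + (u - u_obs)),
// i.e. the Tikhonov objective with the forward map replaced by its
// linearization at the current iterate.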

std::shared_ptr<TerminationReason> IP_SSATaucTikhonovGNSolver::check_convergence() {

  double designNorm, stateNorm, sumNorm;
  double dWeight, sWeight;
  dWeight = m_alpha;
  sWeight = 1;

  designNorm = m_grad_design.norm(NORM_2)[0];
  stateNorm  = m_grad_state.norm(NORM_2)[0];

  designNorm *= dWeight;
  stateNorm  *= sWeight;

  sumNorm = m_gradient.norm(NORM_2)[0];

  m_log->message(2,
                 "----------------------------------------------------------\n");
  m_log->message(2,
                 "IP_SSATaucTikhonovGNSolver Iteration %d: misfit %g; functional %g \n",
                 m_iter, sqrt(m_val_state), m_value);
  if (m_tikhonov_adaptive) {
    m_log->message(2, "alpha %g; log(alpha) %g\n", m_alpha, m_logalpha);
  }
  double relsum = (sumNorm/std::max(designNorm, stateNorm));
  m_log->message(2,
                 "design norm %g stateNorm %g sum %g; relative difference %g\n",
                 designNorm, stateNorm, sumNorm, relsum);

  // If the Tikhonov parameter is adaptive, check whether the
  // discrepancy-principle constraint is met first.
  if (m_tikhonov_adaptive) {
    double disc_ratio = fabs((sqrt(m_val_state)/m_target_misfit) - 1.);
    if (disc_ratio > m_tikhonov_ptol) {
      return GenericTerminationReason::keep_iterating();
    }
  }

  if (sumNorm < m_tikhonov_atol) {
    return std::shared_ptr<TerminationReason>(new GenericTerminationReason(1, "TIKHONOV_ATOL"));
  }

  if (sumNorm < m_tikhonov_rtol*std::max(designNorm, stateNorm)) {
    return std::shared_ptr<TerminationReason>(new GenericTerminationReason(1, "TIKHONOV_RTOL"));
  }

  if (m_iter > m_iter_max) {
    return GenericTerminationReason::max_iter();
  } else {
    return GenericTerminationReason::keep_iterating();
  }
}
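
// Three tolerances govern the tests above: inverse.tikhonov.atol bounds the
// absolute norm of the gradient of the objective, inverse.tikhonov.rtol
// bounds that norm relative to the larger of its design and state parts, and
// (with -tikhonov_adaptive) inverse.tikhonov.ptol bounds the relative
// mismatch between sqrt(m_val_state) and the target misfit before
// convergence may be declared.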

std::shared_ptr<TerminationReason> IP_SSATaucTikhonovGNSolver::evaluate_objective_and_gradient() {

  std::shared_ptr<TerminationReason> reason = m_ssaforward.linearize_at(*m_d);
  if (reason->failed()) {
    return reason;
  }

  m_d_diff.copy_from(*m_d);
  m_d_diff.add(-1, m_d0);

  m_u_diff.copy_from(*m_ssaforward.solution());
  m_u_diff.add(-1, m_u_obs);

  m_designFunctional.gradientAt(m_d_diff, m_grad_design);

  // The following computes the reduced gradient.
  StateVec &adjointRHS = m_tmp_S1Global;
  m_stateFunctional.gradientAt(m_u_diff, adjointRHS);
  m_ssaforward.apply_linearization_transpose(adjointRHS, m_grad_state);

  m_gradient.copy_from(m_grad_design);
  m_gradient.scale(m_alpha);
  m_gradient.add(1, m_grad_state);

  double valDesign, valState;
  m_designFunctional.valueAt(m_d_diff, &valDesign);
  m_stateFunctional.valueAt(m_u_diff, &valState);

  m_val_design = valDesign;
  m_val_state  = valState;

  m_value = valDesign * m_alpha + valState;

  return reason;
}

std::shared_ptr<TerminationReason> IP_SSATaucTikhonovGNSolver::linesearch() {
  PetscErrorCode ierr;

  std::shared_ptr<TerminationReason> step_reason;

  double old_value = m_val_design * m_alpha + m_val_state;

  double descent_derivative;
  m_tmp_D1Global.copy_from(m_h);

  ierr = VecDot(m_gradient.vec(), m_tmp_D1Global.vec(), &descent_derivative);
  PISM_CHK(ierr, "VecDot");

  if (descent_derivative >= 0) {
    printf("descent derivative: %g\n", descent_derivative);
    return std::shared_ptr<TerminationReason>(new GenericTerminationReason(-1, "Not a descent direction"));
  }

  double alpha = 1;
  m_tmp_D1Local.copy_from(*m_d);
  while (true) {
    m_d->add(alpha, m_h); // Replace with line search.
    step_reason = this->evaluate_objective_and_gradient();
    if (step_reason->succeeded()) {
      if (m_value <= old_value + 1e-3*alpha*descent_derivative) {
        break;
      }
    } else {
      printf("forward solve failed in linesearch. Shrinking.\n");
    }
    alpha *= .5;
    if (alpha < 1e-20) {
      printf("alpha= %g; derivative = %g\n", alpha, descent_derivative);
      return std::shared_ptr<TerminationReason>(new GenericTerminationReason(-1, "Too many step shrinks."));
    }
    m_d->copy_from(m_tmp_D1Local);
  }

  return GenericTerminationReason::success();
}
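
// linesearch() is a backtracking (Armijo) search: starting from a full step
// along h, it halves the step length until the sufficient-decrease condition
//   J(d + alpha*h) <= J(d) + 1e-3 * alpha * <gradient, h>
// holds, restoring d from m_tmp_D1Local after each rejected trial step.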

std::shared_ptr<TerminationReason> IP_SSATaucTikhonovGNSolver::solve() {

  if (m_target_misfit == 0) {
    throw RuntimeError::formatted(PISM_ERROR_LOCATION, "Call set target misfit prior to calling"
                                  " IP_SSATaucTikhonovGNSolver::solve.");
  }

  m_iter = 0;
  m_d->copy_from(m_d0);

  double dlogalpha = 0;

  std::shared_ptr<TerminationReason> step_reason, reason;

  step_reason = this->evaluate_objective_and_gradient();
  if (step_reason->failed()) {
    reason.reset(new GenericTerminationReason(-1, "Forward solve"));
    reason->set_root_cause(step_reason);
    return reason;
  }

  while (true) {

    reason = this->check_convergence();
    if (reason->done()) {
      return reason;
    }

    if (m_tikhonov_adaptive) {
      m_logalpha += dlogalpha;
      m_alpha = exp(m_logalpha);
    }

    step_reason = this->solve_linearized();
    if (step_reason->failed()) {
      reason.reset(new GenericTerminationReason(-1, "Gauss Newton solve"));
      reason->set_root_cause(step_reason);
      return reason;
    }

    step_reason = this->linesearch();
    if (step_reason->failed()) {
      std::shared_ptr<TerminationReason> cause = reason;
      reason.reset(new GenericTerminationReason(-1, "Linesearch"));
      reason->set_root_cause(step_reason);
      return reason;
    }

    if (m_tikhonov_adaptive) {
      step_reason = this->compute_dlogalpha(&dlogalpha);
      if (step_reason->failed()) {
        std::shared_ptr<TerminationReason> cause = reason;
        reason.reset(new GenericTerminationReason(-1, "Tikhonov penalty update"));
        reason->set_root_cause(step_reason);
        return reason;
      }
    }

    m_iter++;
  }

  return reason;
}
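
// A sketch of typical use, assuming a target-misfit setter as declared in
// IP_SSATaucTikhonovGNSolver.hh (the setter name below is illustrative, not
// verified against the header; solve() throws unless m_target_misfit is set):
//
//   IP_SSATaucTikhonovGNSolver solver(forward, d0, u_obs, eta,
//                                     designFunctional, stateFunctional);
//   solver.setTargetMisfit(target_misfit);
//   std::shared_ptr<TerminationReason> reason = solver.init();
//   if (reason->succeeded()) {
//     reason = solver.solve();
//   }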

std::shared_ptr<TerminationReason> IP_SSATaucTikhonovGNSolver::compute_dlogalpha(double *dlogalpha) {

  PetscErrorCode ierr;

  // Compute the right-hand side for computing dh/dalpha.
  m_d_diff_lin.copy_from(m_d_diff);
  m_d_diff_lin.add(1, m_h);
  m_designFunctional.interior_product(m_d_diff_lin, m_dalpha_rhs);
  m_dalpha_rhs.scale(-1);

  // Solve the linear equation for dh/dalpha.
  ierr = KSPSetOperators(m_ksp, m_mat_GN, m_mat_GN);
  PISM_CHK(ierr, "KSPSetOperators");

  ierr = KSPSolve(m_ksp, m_dalpha_rhs.vec(), m_dh_dalphaGlobal.vec());
  PISM_CHK(ierr, "KSPSolve");

  m_dh_dalpha.copy_from(m_dh_dalphaGlobal);

  KSPConvergedReason ksp_reason;
  ierr = KSPGetConvergedReason(m_ksp, &ksp_reason);
  PISM_CHK(ierr, "KSPGetConvergedReason");

  if (ksp_reason < 0) {
    return std::shared_ptr<TerminationReason>(new KSPTerminationReason(ksp_reason));
  }

  // S1Local contains T(h) + F(x) - u_obs, i.e. the linearized misfit field.
  m_ssaforward.apply_linearization(m_h, m_tmp_S1Local);
  m_tmp_S1Local.update_ghosts();
  m_tmp_S1Local.add(1, m_u_diff);

  // Compute the linearized discrepancy.
  double disc_sq;
  m_stateFunctional.dot(m_tmp_S1Local, m_tmp_S1Local, &disc_sq);

  // There are a number of equivalent ways to compute the derivative of the
  // linearized discrepancy with respect to alpha, some cheaper to compute
  // than others. Their equivalence relies, however, on having an exact
  // solution of the Gauss-Newton step; since we solve it only to a soft
  // tolerance, equivalence is lost. We therefore attempt the cheap
  // computation first and sanity check the result (the derivative must be
  // positive). If the check fails, we fall back to a more expensive formula
  // that is positive by construction.

  double ddisc_sq_dalpha;
  m_designFunctional.dot(m_dh_dalpha, m_d_diff_lin, &ddisc_sq_dalpha);
  ddisc_sq_dalpha *= -2*m_alpha;

  if (ddisc_sq_dalpha <= 0) {
    // Try harder.

    m_log->message(3,
                   "Adaptive Tikhonov sanity check failed (dh/dalpha= %g <= 0)."
                   " Tighten inv_gn_ksp_rtol?\n",
                   ddisc_sq_dalpha);

    // S2Local contains T(dh/dalpha)
    m_ssaforward.apply_linearization(m_dh_dalpha, m_tmp_S2Local);
    m_tmp_S2Local.update_ghosts();

    double ddisc_sq_dalpha_a;
    m_stateFunctional.dot(m_tmp_S2Local, m_tmp_S2Local, &ddisc_sq_dalpha_a);
    double ddisc_sq_dalpha_b;
    m_designFunctional.dot(m_dh_dalpha, m_dh_dalpha, &ddisc_sq_dalpha_b);
    ddisc_sq_dalpha = 2*m_alpha*(ddisc_sq_dalpha_a + m_alpha*ddisc_sq_dalpha_b);

    m_log->message(3,
                   "Adaptive Tikhonov sanity check recovery attempt: dh/dalpha= %g. \n",
                   ddisc_sq_dalpha);

    // This is yet another alternative formula.
    // m_stateFunctional.dot(m_tmp_S1Local, m_tmp_S2Local, &ddisc_sq_dalpha);
    // ddisc_sq_dalpha *= 2;
  }

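  // Newton step in log(alpha): we seek log(alpha) such that
  //   disc_sq(log(alpha)) = m_target_misfit^2,
  // and by the chain rule d(disc_sq)/d(log(alpha)) = alpha * d(disc_sq)/d(alpha),
  // which gives the update computed below.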
  // Newton's method formula.
  *dlogalpha = (m_target_misfit*m_target_misfit - disc_sq)/(ddisc_sq_dalpha*m_alpha);

  // It's easy to take steps that are too big when we are far from the
  // solution, so we limit the step size.
  double stepmax = 3;
  if (fabs(*dlogalpha) > stepmax) {
    double sgn = *dlogalpha > 0 ? 1 : -1;
    *dlogalpha = stepmax*sgn;
  }

  if (*dlogalpha < 0) {
    *dlogalpha *= .5;
  }

  return GenericTerminationReason::success();
}

} // end of namespace inverse
} // end of namespace pism