PISM, A Parallel Ice Sheet Model  stable v2.0.4 committed by Constantine Khrulev on 2022-05-25 12:02:27 -0800
IP_SSATaucTikhonovGNSolver.cc
Go to the documentation of this file.
1 // Copyright (C) 2012, 2013, 2014, 2015, 2016, 2017, 2019, 2020, 2021 David Maxwell and Constantine Khroulev
2 //
3 // This file is part of PISM.
4 //
5 // PISM is free software; you can redistribute it and/or modify it under the
6 // terms of the GNU General Public License as published by the Free Software
7 // Foundation; either version 3 of the License, or (at your option) any later
8 // version.
9 //
10 // PISM is distributed in the hope that it will be useful, but WITHOUT ANY
11 // WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12 // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
13 // details.
14 //
15 // You should have received a copy of the GNU General Public License
16 // along with PISM; if not, write to the Free Software
17 // Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18 
20 #include "pism/util/TerminationReason.hh"
21 #include "pism/util/pism_options.hh"
22 #include "pism/util/ConfigInterface.hh"
23 #include "pism/util/IceGrid.hh"
24 #include "pism/util/Context.hh"
25 #include "pism/util/petscwrappers/Vec.hh"
26 
27 namespace pism {
28 namespace inverse {
29 
31  DesignVec &d0, StateVec &u_obs, double eta,
32  IPInnerProductFunctional<DesignVec> &designFunctional,
33  IPInnerProductFunctional<StateVec> &stateFunctional)
// NOTE(review): C++ initializes members in declaration order, not in the order of
// this list. The vectors below read m_design_stencil_width / m_state_stencil_width,
// so those two members must be declared before them in the class (header not
// shown) -- TODO confirm.
34  : m_design_stencil_width(d0.stencil_width()),
35  m_state_stencil_width(u_obs.stencil_width()),
36  m_ssaforward(ssaforward),
37  m_x(d0.grid(), "x", WITH_GHOSTS, m_design_stencil_width),
38  m_tmp_D1Global(d0.grid(), "work vector", WITHOUT_GHOSTS, 0),
39  m_tmp_D2Global(d0.grid(), "work vector", WITHOUT_GHOSTS, 0),
40  m_tmp_D1Local(d0.grid(), "work vector", WITH_GHOSTS, m_design_stencil_width),
41  m_tmp_D2Local(d0.grid(), "work vector", WITH_GHOSTS, m_design_stencil_width),
// State-space work vectors are also built on d0's grid (presumably the same grid
// as u_obs -- confirm).
42  m_tmp_S1Global(d0.grid(), "work vector", WITHOUT_GHOSTS, 0),
43  m_tmp_S2Global(d0.grid(), "work vector", WITHOUT_GHOSTS, 0),
44  m_tmp_S1Local(d0.grid(), "work vector", WITH_GHOSTS, m_state_stencil_width),
45  m_tmp_S2Local(d0.grid(), "work vector", WITH_GHOSTS, m_state_stencil_width),
46  m_GN_rhs(d0.grid(), "GN_rhs", WITHOUT_GHOSTS, 0),
47  m_d0(d0),
48  m_dGlobal(d0.grid(), "d (sans ghosts)", WITHOUT_GHOSTS, 0),
49  m_d_diff(d0.grid(), "d_diff", WITH_GHOSTS, m_design_stencil_width),
50  m_d_diff_lin(d0.grid(), "d_diff linearized", WITH_GHOSTS, m_design_stencil_width),
51  m_h(d0.grid(), "h", WITH_GHOSTS, m_design_stencil_width),
52  m_hGlobal(d0.grid(), "h (sans ghosts)", WITHOUT_GHOSTS),
53  m_dalpha_rhs(d0.grid(), "dalpha rhs", WITHOUT_GHOSTS),
54  m_dh_dalpha(d0.grid(), "dh_dalpha", WITH_GHOSTS, m_design_stencil_width),
55  m_dh_dalphaGlobal(d0.grid(), "dh_dalpha", WITHOUT_GHOSTS),
56  m_grad_design(d0.grid(), "grad design", WITHOUT_GHOSTS),
// NOTE(review): m_grad_state and m_gradient reuse the name "grad design"
// (copy-paste?). Presumably this only affects diagnostic/output names -- confirm.
57  m_grad_state(d0.grid(), "grad design", WITHOUT_GHOSTS),
58  m_gradient(d0.grid(), "grad design", WITHOUT_GHOSTS),
59  m_u_obs(u_obs),
60  m_u_diff(d0.grid(), "du", WITH_GHOSTS, m_state_stencil_width),
61  m_eta(eta),
62  m_designFunctional(designFunctional),
63  m_stateFunctional(stateFunctional),
// 0.0 acts as the "target misfit not set" sentinel checked at the start of solve().
64  m_target_misfit(0.0)
65 {
// Sets up the Gauss-Newton Tikhonov solver:
//  - a CG Krylov solver with no preconditioner (options prefix "inv_gn_");
//  - a matrix-free PETSc shell matrix for the Gauss-Newton operator;
//  - the initial Tikhonov weight alpha = 1/eta;
//  - command-line options and config tolerances.
66  PetscErrorCode ierr;
67  IceGrid::ConstPtr grid = m_d0.grid();
68  m_comm = grid->com;
69 
// Ghosted copy of the current design variable; solve() resets it from d0.
70  m_d.reset(new DesignVec(grid, "d", WITH_GHOSTS, m_design_stencil_width));
71 
72  ierr = KSPCreate(grid->com, m_ksp.rawptr());
73  PISM_CHK(ierr, "KSPCreate");
74 
75  ierr = KSPSetOptionsPrefix(m_ksp, "inv_gn_");
76  PISM_CHK(ierr, "KSPSetOptionsPrefix");
77 
// Loose relative tolerance for the inner linear solves. Because KSPSetFromOptions
// is called below, -inv_gn_ksp_* options can override this and the type settings.
78  double ksp_rtol = 1e-5; // Soft tolerance
79  ierr = KSPSetTolerances(m_ksp, ksp_rtol, PETSC_DEFAULT, PETSC_DEFAULT, PETSC_DEFAULT);
80  PISM_CHK(ierr, "KSPSetTolerances");
81 
82  ierr = KSPSetType(m_ksp, KSPCG);
83  PISM_CHK(ierr, "KSPSetType");
84 
85  PC pc;
86  ierr = KSPGetPC(m_ksp, &pc);
87  PISM_CHK(ierr, "KSPGetPC");
88 
89  ierr = PCSetType(pc, PCNONE);
90  PISM_CHK(ierr, "PCSetType");
91 
92  ierr = KSPSetFromOptions(m_ksp);
93  PISM_CHK(ierr, "KSPSetFromOptions");
94 
// Shell matrix sized at one unknown per grid node (locally owned xm*ym, globally
// Mx*My); `this` is passed as the shell context.
95  int nLocalNodes = grid->xm()*grid->ym();
96  int nGlobalNodes = grid->Mx()*grid->My();
97  ierr = MatCreateShell(grid->com, nLocalNodes, nLocalNodes,
98  nGlobalNodes, nGlobalNodes, this, m_mat_GN.rawptr());
99  PISM_CHK(ierr, "MatCreateShell");
100 
// Presumably registers apply_GN(Vec, Vec) as the shell matrix's multiply
// operation (the defining lines are not shown in this listing) -- confirm.
103  multCallback::connect(m_mat_GN);
104 
// NOTE(review): eta is assumed nonzero; eta == 0 would make m_alpha infinite.
105  m_alpha = 1./m_eta;
106  m_logalpha = log(m_alpha);
107 
108  m_tikhonov_adaptive = options::Bool("-tikhonov_adaptive", "Tikhonov adaptive");
109 
// Maximum number of Gauss-Newton iterations; defaults to 1000 and can be
// overridden with -inv_gn_iter_max.
110  m_iter_max = 1000;
111  m_iter_max = options::Integer("-inv_gn_iter_max", "", m_iter_max);
112 
// Convergence tolerances used by check_convergence().
113  m_tikhonov_atol = grid->ctx()->config()->get_number("inverse.tikhonov.atol");
114  m_tikhonov_rtol = grid->ctx()->config()->get_number("inverse.tikhonov.rtol");
115  m_tikhonov_ptol = grid->ctx()->config()->get_number("inverse.tikhonov.ptol");
116 
117  m_log = d0.grid()->ctx()->log();
118 }
119 
122 }
// (Closing brace of a definition whose body is not shown in this listing --
// presumably the destructor; confirm.)
123 
// Overload taking IceModelVec-style arguments (signature not shown here).
// It forwards to the PETSc-Vec apply_GN using the vectors' underlying Vecs.
125  this->apply_GN(x.vec(), y.vec());
126 }
127 
128 //! @note This function has to return PetscErrorCode (it is used as a callback).
// Applies the Gauss-Newton operator to the PETSc vector x and writes the result
// to y. Presumably this is the shell-matrix multiply connected in the constructor.
// NOTE(review): the visible body contains no return statement. Unless one sits on
// a line missing from this listing, the @note above is stale. Check it against
// the declaration.
130  StateVec &tmp_gS = m_tmp_S1Global;
131  StateVec &Tx = m_tmp_S1Local;
132  DesignVec &tmp_gD = m_tmp_D1Global;
133  DesignVec &GNx = m_tmp_D2Global;
134 
135  PetscErrorCode ierr;
136  // FIXME: Needless copies for now.
// Scatter the global Vec x into the ghosted design vector m_x.
137  {
138  ierr = DMGlobalToLocalBegin(*m_x.dm(), x, INSERT_VALUES, m_x.vec());
139  PISM_CHK(ierr, "DMGlobalToLocalBegin");
140 
141  ierr = DMGlobalToLocalEnd(*m_x.dm(), x, INSERT_VALUES, m_x.vec());
142  PISM_CHK(ierr, "DMGlobalToLocalEnd");
143  }
144 
// Tx is filled on a line not shown here (presumably the linearized forward map
// applied to m_x -- confirm). Its ghosts are refreshed before use.
146  Tx.update_ghosts();
147 
// tmp_gS <- state inner-product operator applied to Tx.
148  m_stateFunctional.interior_product(Tx,tmp_gS);
149 
151 
// GNx is filled on a line not shown (presumably the transpose of the
// linearization applied to tmp_gS). Then alpha * (design inner-product
// operator applied to m_x) is added to it.
152  m_designFunctional.interior_product(m_x,tmp_gD);
153  GNx.add(m_alpha,tmp_gD);
154 
// Copy the result into the caller's output Vec.
155  ierr = VecCopy(GNx.vec(), y); PISM_CHK(ierr, "VecCopy");
156 }
157 
159 
// Assembles the right-hand side of the Gauss-Newton system into rhs.
// Visible steps:
//  - zero rhs;
//  - apply the state inner product to the misfit m_u_diff;
//  - apply the design inner product to m_d_diff;
//  - negate rhs.
// The lines that accumulate those two work vectors into rhs are not shown.
// Presumably they use the linearization transpose and an alpha-weighted add --
// confirm.
160  rhs.set(0);
161 
162  m_stateFunctional.interior_product(m_u_diff,m_tmp_S1Global);
164 
165  m_designFunctional.interior_product(m_d_diff,m_tmp_D1Global);
167 
168  rhs.scale(-1);
169 }
170 
172  PetscErrorCode ierr;
173 
// Solves the linearized (Gauss-Newton) system for the step:
//  - assemble the right-hand side;
//  - run the KSP with the shell matrix as both operator and preconditioning
//    matrix;
//  - the solution is written into m_hGlobal.
174  this->assemble_GN_rhs(m_GN_rhs);
175 
176  ierr = KSPSetOperators(m_ksp,m_mat_GN,m_mat_GN);
177  PISM_CHK(ierr, "KSPSetOperators");
178 
179  ierr = KSPSolve(m_ksp,m_GN_rhs.vec(),m_hGlobal.vec());
180  PISM_CHK(ierr, "KSPSolve");
181 
182  KSPConvergedReason ksp_reason;
183  ierr = KSPGetConvergedReason(m_ksp ,&ksp_reason);
184  PISM_CHK(ierr, "KSPGetConvergedReason");
185 
// (A line not shown here presumably copies m_hGlobal into the ghosted m_h --
// confirm.)
187 
// The KSP outcome is returned as-is. solve() treats a failed reason as a failed
// Gauss-Newton solve.
188  return TerminationReason::Ptr(new KSPTerminationReason(ksp_reason));
189 }
190 
192 
196 
// Evaluates the Gauss-Newton functional for a trial step h:
//   *value = m_alpha * (design functional) + (state functional).
// The state argument m_tmp_S1Local is prepared on lines not shown in this
// listing.
197  double sValue;
198  m_stateFunctional.valueAt(m_tmp_S1Local,&sValue);
199 
// m_tmp_D1Local is initialized on a line not shown (presumably from m_d_diff --
// confirm) and then shifted by h.
201  m_tmp_D1Local.add(1,h);
202 
203  double dValue;
204  m_designFunctional.valueAt(m_tmp_D1Local,&dValue);
205 
206  *value = m_alpha*dValue + sValue;
207 }
208 
209 
211 
// Logs the current iterate and decides whether to stop. Checks, in order:
//  1. adaptive-alpha discrepancy (only if enabled);
//  2. absolute gradient tolerance (TIKHONOV_ATOL);
//  3. relative gradient tolerance (TIKHONOV_RTOL);
//  4. iteration cap.
212  double designNorm, stateNorm, sumNorm;
213  double dWeight, sWeight;
214  dWeight = m_alpha;
215  sWeight = 1;
216 
// 2-norms of the design and state gradient pieces. The design part is weighted
// by alpha so both are on the same scale as the combined gradient.
217  designNorm = m_grad_design.norm(NORM_2)[0];
218  stateNorm = m_grad_state.norm(NORM_2)[0];
219 
220  designNorm *= dWeight;
221  stateNorm *= sWeight;
222 
// 2-norm of the combined gradient.
223  sumNorm = m_gradient.norm(NORM_2)[0];
224 
225  m_log->message(2,
226  "----------------------------------------------------------\n");
227  m_log->message(2,
228  "IP_SSATaucTikhonovGNSolver Iteration %d: misfit %g; functional %g \n",
230  if (m_tikhonov_adaptive) {
231  m_log->message(2, "alpha %g; log(alpha) %g\n", m_alpha, m_logalpha);
232  }
// NOTE(review): if both weighted norms are zero, this division yields inf/NaN.
// It only feeds the log message below; the RTOL test multiplies instead of
// dividing.
233  double relsum = (sumNorm/std::max(designNorm,stateNorm));
234  m_log->message(2,
235  "design norm %g stateNorm %g sum %g; relative difference %g\n",
236  designNorm, stateNorm, sumNorm, relsum);
237 
238  // If we have an adaptive tikhonov parameter, check if we have met
239  // this constraint first.
240  if (m_tikhonov_adaptive) {
// Relative mismatch between sqrt(state functional value) and the target misfit.
241  double disc_ratio = fabs((sqrt(m_val_state)/m_target_misfit) - 1.);
242  if (disc_ratio > m_tikhonov_ptol) {
// (The line not shown here presumably returns a keep-iterating reason while the
// discrepancy is outside ptol -- confirm.)
244  }
245  }
246 
247  if (sumNorm < m_tikhonov_atol) {
248  return TerminationReason::Ptr(new GenericTerminationReason(1,"TIKHONOV_ATOL"));
249  }
250 
251  if (sumNorm < m_tikhonov_rtol*std::max(designNorm,stateNorm)) {
252  return TerminationReason::Ptr(new GenericTerminationReason(1,"TIKHONOV_RTOL"));
253  }
254 
// NOTE(review): solve() starts m_iter at 0 and increments it after each step.
// The strict '>' therefore allows m_iter_max + 1 Gauss-Newton steps. Confirm
// whether '>=' was intended.
255  if (m_iter>m_iter_max) {
257  } else {
259  }
260 }
261 
263 
// Refreshes everything that depends on the current design variable:
//  - re-linearizes the forward problem (on a line not shown, presumably at *m_d
//    -- confirm);
//  - updates the design/state differences and the gradients;
//  - caches the objective values.
// If the forward solve fails, its reason is returned immediately and the cached
// values below are left as they were.
265  if (reason->failed()) {
266  return reason;
267  }
268 
// m_d_diff = d - d0. The copy of d happens on a line not shown -- confirm.
270  m_d_diff.add(-1,m_d0);
271 
// m_u_diff = u - u_obs. The copy of the forward solution happens on a line not
// shown -- confirm.
273  m_u_diff.add(-1,m_u_obs);
274 
276 
277  // The following computes the reduced gradient.
278  StateVec &adjointRHS = m_tmp_S1Global;
279  m_stateFunctional.gradientAt(m_u_diff,adjointRHS);
281 
285 
// Cache the two functional values and the Tikhonov objective
// m_value = alpha * D + S.
286  double valDesign, valState;
287  m_designFunctional.valueAt(m_d_diff,&valDesign);
288  m_stateFunctional.valueAt(m_u_diff,&valState);
289 
290  m_val_design = valDesign;
291  m_val_state = valState;
292 
293  m_value = valDesign * m_alpha + valState;
294 
295  return reason;
296 }
297 
299  PetscErrorCode ierr;
// Backtracking line search along the Gauss-Newton step h:
//  - start with a full step (alpha = 1);
//  - halve alpha until the sufficient-decrease (Armijo) test with constant 1e-3
//    passes;
//  - give up once alpha drops below 1e-20.
// On success, m_d, m_value and the gradients correspond to the accepted step.
300 
301  TerminationReason::Ptr step_reason;
302 
// Recompute the starting objective with the current m_alpha. solve() may have
// changed m_alpha since m_value was last stored.
303  double old_value = m_val_design * m_alpha + m_val_state;
304 
305  double descent_derivative;
306 
308 
// Directional derivative <gradient, step>. m_tmp_D1Global is filled on a line not
// shown (presumably from m_h -- confirm).
309  ierr = VecDot(m_gradient.vec(), m_tmp_D1Global.vec(), &descent_derivative);
310  PISM_CHK(ierr, "VecDot");
311 
312  if (descent_derivative >=0) {
// NOTE(review): inside namespace pism, unqualified printf appears to resolve to
// pism::printf. This page lists that as `std::string printf(const char *format,...)`:
// it returns the formatted string instead of writing it. If so, this call and the
// two printf calls below print nothing. Consider m_log->message(...) instead.
313  printf("descent derivative: %g\n",descent_derivative);
314  return TerminationReason::Ptr(new GenericTerminationReason(-1, "Not descent direction"));
315  }
316 
317  double alpha = 1;
319  while(true) {
320  m_d->add(alpha,m_h); // Trial step: d <- d + alpha*h.
321  step_reason = this->evaluate_objective_and_gradient();
322  if (step_reason->succeeded()) {
323  if (m_value <= old_value + 1e-3*alpha*descent_derivative) {
324  break;
325  }
326  }
327  else {
// A failed forward solve is treated like insufficient decrease: shrink and retry.
328  printf("forward solve failed in linsearch. Shrinking.\n");
329  }
330  alpha *=.5;
331  if (alpha<1e-20) {
// NOTE(review): this return happens before the restore below. m_d is therefore
// left at the last rejected trial point.
332  printf("alpha= %g; derivative = %g\n",alpha,descent_derivative);
333  return TerminationReason::Ptr(new GenericTerminationReason(-1, "Too many step shrinks."));
334  }
// Restore d before the next, shorter trial. m_tmp_D1Local is saved on a line not
// shown, before the loop -- confirm.
335  m_d->copy_from(m_tmp_D1Local);
336  }
337 
339 }
340 
342 
// Main Gauss-Newton loop. It requires a target misfit, starts from d0 and
// evaluates the objective. Each iteration then:
//  1. checks convergence;
//  2. applies the pending log(alpha) step (adaptive mode only);
//  3. solves the linearized system;
//  4. line-searches along the step;
//  5. computes the next log(alpha) step (adaptive mode only).
// m_target_misfit is initialized to 0.0 in the constructor, and 0 means
// "not set".
343  if (m_target_misfit == 0) {
344  throw RuntimeError::formatted(PISM_ERROR_LOCATION, "Call set target misfit prior to calling"
345  " IP_SSATaucTikhonovGNSolver::solve.");
346  }
347 
348  m_iter = 0;
349  m_d->copy_from(m_d0);
350 
351  double dlogalpha = 0;
352 
353  TerminationReason::Ptr step_reason, reason;
354 
355  step_reason = this->evaluate_objective_and_gradient();
356  if (step_reason->failed()) {
357  reason.reset(new GenericTerminationReason(-1,"Forward solve"));
358  reason->set_root_cause(step_reason);
359  return reason;
360  }
361 
362  while(true) {
363 
364  reason = this->check_convergence();
365  if (reason->done()) {
366  return reason;
367  }
368 
// Apply the log(alpha) step computed at the end of the previous iteration.
// It is 0 on the first pass.
369  if (m_tikhonov_adaptive) {
370  m_logalpha += dlogalpha;
371  m_alpha = exp(m_logalpha);
372  }
373 
374  step_reason = this->solve_linearized();
375  if (step_reason->failed()) {
376  reason.reset(new GenericTerminationReason(-1,"Gauss Newton solve"));
377  reason->set_root_cause(step_reason);
378  return reason;
379  }
380 
381  step_reason = this->linesearch();
382  if (step_reason->failed()) {
// NOTE(review): 'cause' is never used, here or in the Tikhonov branch below.
// It can be removed.
383  TerminationReason::Ptr cause = reason;
384  reason.reset(new GenericTerminationReason(-1,"Linesearch"));
385  reason->set_root_cause(step_reason);
386  return reason;
387  }
388 
389  if (m_tikhonov_adaptive) {
390  step_reason = this->compute_dlogalpha(&dlogalpha);
391  if (step_reason->failed()) {
392  TerminationReason::Ptr cause = reason;
393  reason.reset(new GenericTerminationReason(-1,"Tikhonov penalty update"));
394  reason->set_root_cause(step_reason);
395  return reason;
396  }
397  }
398 
399  m_iter++;
400  }
401 
// NOTE(review): unreachable. The loop above only exits via return.
402  return reason;
403 }
404 
406 
// Computes a Newton step *dlogalpha for log(alpha). The step drives the squared
// linearized discrepancy disc_sq toward m_target_misfit^2. Steps:
//  1. solve the Gauss-Newton system again for dh/dalpha;
//  2. estimate d(disc_sq)/d(alpha), with a fallback formula if the cheap
//     estimate is not positive;
//  3. take the Newton step;
//  4. clamp it to [-3, 3] and halve negative steps.
// A failed KSP solve is returned as-is.
407  PetscErrorCode ierr;
408 
409  // Compute the right-hand side for computing dh/dalpha.
410  // (m_d_diff_lin is initialized on a line not shown, presumably from m_d_diff --
411  // confirm.)
411  m_d_diff_lin.add(1,m_h);
412  m_designFunctional.interior_product(m_d_diff_lin,m_dalpha_rhs);
413  m_dalpha_rhs.scale(-1);
414 
415  // Solve linear equation for dh/dalpha.
416  ierr = KSPSetOperators(m_ksp,m_mat_GN,m_mat_GN);
417  PISM_CHK(ierr, "KSPSetOperators");
418 
419  ierr = KSPSolve(m_ksp,m_dalpha_rhs.vec(),m_dh_dalphaGlobal.vec());
420  PISM_CHK(ierr, "KSPSolve");
421 
423 
424  KSPConvergedReason ksp_reason;
425  ierr = KSPGetConvergedReason(m_ksp,&ksp_reason);
426  PISM_CHK(ierr, "KSPGetConvergedReason");
427 
428  if (ksp_reason<0) {
429  return TerminationReason::Ptr(new KSPTerminationReason(ksp_reason));
430  }
431 
432  // S1Local contains T(h) + F(x) - u_obs, i.e. the linearized misfit field.
436 
437  // Compute linearized discrepancy.
438  double disc_sq;
440 
441  // There are a number of equivalent ways to compute the derivative of the
442  // linearized discrepancy with respect to alpha, some of which are cheaper
443  // than others to compute. This equivalency relies, however, on having an
444  // exact solution in the Gauss-Newton step. Since we only solve this with
445  // a soft tolerance, we lose equivalency. We attempt a cheap computation,
446  // and then do a sanity check (namely that the derivative is positive).
447  // If this fails, we compute by a harder way that inherently yields a
448  // positive number.
449 
450  double ddisc_sq_dalpha;
451  m_designFunctional.dot(m_dh_dalpha,m_d_diff_lin,&ddisc_sq_dalpha);
452  ddisc_sq_dalpha *= -2*m_alpha;
453 
454  if (ddisc_sq_dalpha <= 0) {
455  // Try harder.
456 
// NOTE(review): this message and the one below label the printed value
// "dh/dalpha". The value is actually ddisc_sq_dalpha, the derivative of the
// squared discrepancy.
457  m_log->message(3,
458  "Adaptive Tikhonov sanity check failed (dh/dalpha= %g <= 0)."
459  " Tighten inv_gn_ksp_rtol?\n",
460  ddisc_sq_dalpha);
461 
462  // S2Local contains T(dh/dalpha)
465 
466  double ddisc_sq_dalpha_a;
467  m_stateFunctional.dot(m_tmp_S2Local,m_tmp_S2Local,&ddisc_sq_dalpha_a);
468  double ddisc_sq_dalpha_b;
469  m_designFunctional.dot(m_dh_dalpha,m_dh_dalpha,&ddisc_sq_dalpha_b);
470  ddisc_sq_dalpha = 2*m_alpha*(ddisc_sq_dalpha_a+m_alpha*ddisc_sq_dalpha_b);
471 
472  m_log->message(3,
473  "Adaptive Tikhonov sanity check recovery attempt: dh/dalpha= %g. \n",
474  ddisc_sq_dalpha);
475 
476  // This is yet another alternative formula.
477  // m_stateFunctional.dot(m_tmp_S1Local,m_tmp_S2Local,&ddisc_sq_dalpha);
478  // ddisc_sq_dalpha *= 2;
479  }
480 
481  // Newton's method formula.
482  // Works in log(alpha): d(disc_sq)/d(log alpha) = alpha * d(disc_sq)/d(alpha).
483  // NOTE(review): if ddisc_sq_dalpha is still 0 here, the quotient is +-inf or NaN.
484  // The clamp below handles +-inf, but a 0/0 NaN would pass through unchanged.
482  *dlogalpha = (m_target_misfit*m_target_misfit-disc_sq)/(ddisc_sq_dalpha*m_alpha);
483 
484  // It's easy to take steps that are too big when we are far from the solution.
485  // So we limit the step size.
486  double stepmax = 3;
487  if (fabs(*dlogalpha)> stepmax) {
488  double sgn = *dlogalpha > 0 ? 1 : -1;
489  *dlogalpha = stepmax*sgn;
490  }
491 
// Be more cautious when decreasing alpha: halve negative steps.
492  if (*dlogalpha<0) {
493  *dlogalpha*=.5;
494  }
495 
497 }
498 
499 } // end of namespace inverse
500 } // end of namespace pism
static TerminationReason::Ptr keep_iterating()
static TerminationReason::Ptr success()
static TerminationReason::Ptr max_iter()
std::shared_ptr< const IceGrid > ConstPtr
Definition: IceGrid.hh:233
void add(double alpha, const IceModelVec2S &x)
void copy_from(const IceModelVec2S &source)
void copy_from(const IceModelVec2< T > &source)
Definition: IceModelVec2.hh:97
void add(double alpha, const IceModelVec2< T > &x)
Definition: IceModelVec2.hh:89
void update_ghosts()
Updates ghost points.
Definition: iceModelVec.cc:669
petsc::Vec & vec() const
Definition: iceModelVec.cc:342
std::shared_ptr< petsc::DM > dm() const
Definition: iceModelVec.cc:356
void set(double c)
Result: v[j] <- c for all j.
Definition: iceModelVec.cc:683
void scale(double alpha)
Result: v <- v * alpha. Calls VecScale.
Definition: iceModelVec.cc:252
IceGrid::ConstPtr grid() const
Definition: iceModelVec.cc:128
std::vector< double > norm(int n) const
Computes the norm of all the components of an IceModelVec.
Definition: iceModelVec.cc:769
static RuntimeError formatted(const ErrorLocation &location, const char format[],...) __attribute__((format(printf
build a RuntimeError with a formatted message
std::shared_ptr< TerminationReason > Ptr
Abstract base class for IPFunctionals arising from an inner product.
Definition: IPFunctional.hh:93
virtual IceModelVec2V::Ptr solution()
Returns the last solution of the SSA as computed by linearize_at.
virtual void apply_linearization_transpose(IceModelVec2V &du, IceModelVec2S &dzeta)
Applies the transpose of the linearization of the forward map (i.e. the transpose of the reduced gradient described in the class-level documentation).
virtual void apply_linearization(IceModelVec2S &dzeta, IceModelVec2V &du)
Applies the linearization of the forward map (i.e. the reduced gradient described in the class-level documentation).
virtual TerminationReason::Ptr linearize_at(IceModelVec2S &zeta)
Sets the current value of the design variable \(\zeta\) and solves the SSA to find the associated velocity \(u_{\rm SSA}\).
Implements the forward problem of the map taking \(\tau_c\) to the corresponding solution of the SSA.
virtual void evaluateGNFunctional(DesignVec &h, double *value)
virtual TerminationReason::Ptr check_convergence()
IPInnerProductFunctional< StateVec > & m_stateFunctional
virtual TerminationReason::Ptr evaluate_objective_and_gradient()
IPInnerProductFunctional< DesignVec > & m_designFunctional
virtual TerminationReason::Ptr compute_dlogalpha(double *dalpha)
virtual void apply_GN(IceModelVec2S &h, IceModelVec2S &out)
IP_SSATaucTikhonovGNSolver(IP_SSATaucForwardProblem &ssaforward, DesignVec &d0, StateVec &u_obs, double eta, IPInnerProductFunctional< DesignVec > &designFunctional, IPInnerProductFunctional< StateVec > &stateFunctional)
#define PISM_CHK(errcode, name)
#define PISM_ERROR_LOCATION
bool Bool(const std::string &option, const std::string &description)
Definition: options.cc:240
double max(const IceModelVec2S &input)
Finds maximum over all the values in an IceModelVec2S object. Ignores ghosts.
std::string printf(const char *format,...)
@ WITHOUT_GHOSTS
Definition: iceModelVec.hh:49
@ WITH_GHOSTS
Definition: iceModelVec.hh:49