PISM, A Parallel Ice Sheet Model 2.3.0-79cae578d committed by Constantine Khrulev on 2026-03-22
Loading...
Searching...
No Matches
IP_SSATaucTikhonovGNSolver.cc
Go to the documentation of this file.
1// Copyright (C) 2012--2025 David Maxwell and Constantine Khroulev
2//
3// This file is part of PISM.
4//
5// PISM is free software; you can redistribute it and/or modify it under the
6// terms of the GNU General Public License as published by the Free Software
7// Foundation; either version 3 of the License, or (at your option) any later
8// version.
9//
10// PISM is distributed in the hope that it will be useful, but WITHOUT ANY
11// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12// FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
13// details.
14//
15// You should have received a copy of the GNU General Public License
16// along with PISM; if not, write to the Free Software
17// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18
19#include "pism/inverse/IP_SSATaucTikhonovGNSolver.hh"
20#include "pism/util/TerminationReason.hh"
21#include "pism/util/pism_options.hh"
22#include "pism/util/Config.hh"
23#include "pism/util/Grid.hh"
24#include "pism/util/Context.hh"
25#include "pism/util/petscwrappers/Vec.hh"
26#include "pism/util/petscwrappers/DM.hh"
27#include "pism/util/Logger.hh"
28
29namespace pism {
30namespace inverse {
31
// NOTE(review): this is a Doxygen-extracted listing; the leading numbers on
// each line are embedded listing line numbers, and listing lines 32 and
// 34--35 (the start of the constructor signature) are missing from this view.
// The constructor stores references to the forward problem, the observed
// state, and the two inner-product functionals; allocates per-grid work
// vectors; and creates the PETSc KSP and shell matrix used for the
// Gauss-Newton step.
33 DesignVec &d0, StateVec &u_obs, double eta,
36 : m_design_stencil_width(d0.stencil_width()),
37 m_state_stencil_width(u_obs.stencil_width()),
38 m_ssaforward(ssaforward),
39 m_x(d0.grid(), "x"),
40 m_tmp_D1Global(d0.grid(), "work vector"),
41 m_tmp_D2Global(d0.grid(), "work vector"),
42 m_tmp_D1Local(d0.grid(), "work vector"),
43 m_tmp_D2Local(d0.grid(), "work vector"),
44 m_tmp_S1Global(d0.grid(), "work vector"),
45 m_tmp_S2Global(d0.grid(), "work vector"),
46 m_tmp_S1Local(d0.grid(), "work vector"),
47 m_tmp_S2Local(d0.grid(), "work vector"),
48 m_GN_rhs(d0.grid(), "GN_rhs"),
49 m_d0(d0),
50 m_dGlobal(d0.grid(), "d (sans ghosts)"),
51 m_d_diff(d0.grid(), "d_diff"),
52 m_d_diff_lin(d0.grid(), "d_diff linearized"),
53 m_h(d0.grid(), "h"),
54 m_hGlobal(d0.grid(), "h (sans ghosts)"),
55 m_dalpha_rhs(d0.grid(), "dalpha rhs"),
56 m_dh_dalpha(d0.grid(), "dh_dalpha"),
57 m_dh_dalphaGlobal(d0.grid(), "dh_dalpha"),
58 m_grad_design(d0.grid(), "grad design"),
// NOTE(review): the next two vectors reuse the human-readable name
// "grad design" even though they hold the state gradient and the combined
// gradient -- this looks like a copy-paste slip in the names; confirm
// before changing, since the strings may appear in output files.
59 m_grad_state(d0.grid(), "grad design"),
60 m_gradient(d0.grid(), "grad design"),
61 m_u_obs(u_obs),
62 m_u_diff(d0.grid(), "du"),
63 m_eta(eta),
64 m_designFunctional(designFunctional),
65 m_stateFunctional(stateFunctional),
// m_target_misfit starts at zero; solve() throws unless it is set first
// (see the check at the top of solve()).
66 m_target_misfit(0.0)
67{
68 PetscErrorCode ierr;
69 std::shared_ptr<const Grid> grid = m_d0.grid();
70 m_comm = grid->com;
71
// Ghosted copy of the current design iterate.
72 m_d = std::make_shared<DesignVecGhosted>(grid, "d");
73
// KSP used for the Gauss-Newton step; its options use the "inv_gn_" prefix.
74 ierr = KSPCreate(grid->com, m_ksp.rawptr());
75 PISM_CHK(ierr, "KSPCreate");
76
77 ierr = KSPSetOptionsPrefix(m_ksp, "inv_gn_");
78 PISM_CHK(ierr, "KSPSetOptionsPrefix");
79
// Deliberately loose relative tolerance; compute_dlogalpha() below
// discusses the consequences of this soft tolerance.
80 double ksp_rtol = 1e-5; // Soft tolerance
81 ierr = KSPSetTolerances(m_ksp, ksp_rtol, PETSC_DEFAULT, PETSC_DEFAULT, PETSC_DEFAULT);
82 PISM_CHK(ierr, "KSPSetTolerances");
83
// Conjugate gradients with no preconditioner: the operator is only
// available matrix-free (shell matrix created below).
84 ierr = KSPSetType(m_ksp, KSPCG);
85 PISM_CHK(ierr, "KSPSetType");
86
87 PC pc;
88 ierr = KSPGetPC(m_ksp, &pc);
89 PISM_CHK(ierr, "KSPGetPC");
90
91 ierr = PCSetType(pc, PCNONE);
92 PISM_CHK(ierr, "PCSetType");
93
// Allow -inv_gn_* command-line options to override the settings above.
94 ierr = KSPSetFromOptions(m_ksp);
95 PISM_CHK(ierr, "KSPSetFromOptions");
96
// Matrix-free Gauss-Newton operator sized by local/global node counts.
// NOTE(review): listing lines 103--104 are missing here; line 105 attaches
// the matrix-vector product callback (apply_GN) to the shell matrix.
97 int nLocalNodes = grid->xm()*grid->ym();
98 int nGlobalNodes = grid->Mx()*grid->My();
99 ierr = MatCreateShell(grid->com, nLocalNodes, nLocalNodes,
100 nGlobalNodes, nGlobalNodes, this, m_mat_GN.rawptr());
101 PISM_CHK(ierr, "MatCreateShell");
102
105 multCallback::connect(m_mat_GN);
106
// Tikhonov penalty parameter alpha = 1/eta; iterated in log space when
// the adaptive option is on.
107 m_alpha = 1./m_eta;
108 m_logalpha = log(m_alpha);
109
110 m_iter_max = 1000;
111 m_iter_max = options::Integer("-inv_gn_iter_max", "", m_iter_max);
112
113 auto config = grid->ctx()->config();
114
// Convergence/adaptivity settings read from the PISM configuration.
115 m_tikhonov_adaptive = config->get_flag("inverse.tikhonov.adaptive");
116 m_tikhonov_atol = config->get_number("inverse.tikhonov.atol");
117 m_tikhonov_rtol = config->get_number("inverse.tikhonov.rtol");
118 m_tikhonov_ptol = config->get_number("inverse.tikhonov.ptol");
119
120 m_log = d0.grid()->ctx()->log();
121}
122
123std::shared_ptr<TerminationReason> IP_SSATaucTikhonovGNSolver::init() {
// NOTE(review): the single body line (listing line 124) is missing from
// this extraction; only the signature and closing brace are visible, so
// the returned TerminationReason cannot be documented from this view.
 125}
126
130
131//! @note This function has to return PetscErrorCode (it is used as a callback).
// NOTE(review): listing lines 132 (the function signature), 134, 136, 148
// and 153 are missing from this extraction. Per the class declaration this
// is the shell-matrix product apply_GN, computing
// y = (T^* W_S T + alpha W_D) x -- confirm against the full source.
133 StateVec &tmp_gS = m_tmp_S1Global;
135 DesignVec &tmp_gD = m_tmp_D1Global;
137
138 PetscErrorCode ierr;
139 // FIXME: Needless copies for now.
140 {
// Scatter the raw global input Vec x into the ghosted work vector m_x.
141 ierr = DMGlobalToLocalBegin(*m_x.dm(), x, INSERT_VALUES, m_x.vec());
142 PISM_CHK(ierr, "DMGlobalToLocalBegin");
143
144 ierr = DMGlobalToLocalEnd(*m_x.dm(), x, INSERT_VALUES, m_x.vec());
145 PISM_CHK(ierr, "DMGlobalToLocalEnd");
146 }
147
// Tx presumably holds the linearized forward map applied to x; the
// statement computing it (listing line 148) is not visible here.
149 Tx.update_ghosts();
150
// State-space inner product of Tx (the T^* W_S T part of the product).
151 m_stateFunctional.interior_product(Tx,tmp_gS);
152
154
// Tikhonov part: add alpha * W_D x.
155 m_designFunctional.interior_product(m_x,tmp_gD);
156 GNx.add(m_alpha,tmp_gD);
157
// Copy the result into the raw PETSc output vector expected by PETSc.
158 ierr = VecCopy(GNx.vec(), y); PISM_CHK(ierr, "VecCopy");
159}
160
173
174std::shared_ptr<TerminationReason> IP_SSATaucTikhonovGNSolver::solve_linearized() {
175 PetscErrorCode ierr;
176
// NOTE(review): listing lines 177 and 189 are missing from this
// extraction; line 177 presumably assembles m_GN_rhs before the solve.
178
// Use the shell Gauss-Newton operator as both system matrix and
// preconditioning matrix (the PC is PCNONE, set in the constructor).
179 ierr = KSPSetOperators(m_ksp,m_mat_GN,m_mat_GN);
180 PISM_CHK(ierr, "KSPSetOperators");
181
// Solve for the Gauss-Newton step m_hGlobal.
182 ierr = KSPSolve(m_ksp,m_GN_rhs.vec(),m_hGlobal.vec());
183 PISM_CHK(ierr, "KSPSolve");
184
// Report KSP convergence status to the caller as a TerminationReason.
185 KSPConvergedReason ksp_reason;
186 ierr = KSPGetConvergedReason(m_ksp ,&ksp_reason);
187 PISM_CHK(ierr, "KSPGetConvergedReason");
188
190
191 return std::shared_ptr<TerminationReason>(new KSPTerminationReason(ksp_reason));
192}
193
195
// NOTE(review): listing lines 196--198 (the evaluateGNFunctional signature
// and its first statements) and 203 are missing from this extraction. Per
// the class declaration this evaluates the Gauss-Newton model functional
// at a candidate step h, writing alpha*J_design + J_state to *value.
199
200 double sValue;
201 m_stateFunctional.valueAt(m_tmp_S1Local,&sValue);
202
// m_tmp_D1Local is presumably seeded on the missing line 203; adding h
// gives the design residual at the trial point.
204 m_tmp_D1Local.add(1,h);
205
206 double dValue;
207 m_designFunctional.valueAt(m_tmp_D1Local,&dValue);
208
// Combined Tikhonov model value: alpha * design part + state part.
209 *value = m_alpha*dValue + sValue;
210}
211
212
213std::shared_ptr<TerminationReason> IP_SSATaucTikhonovGNSolver::check_convergence() {
214
// Norms of the design-part and state-part gradients, weighted so they are
// commensurate in the combined functional (the design part scales with
// alpha; the state weight is 1).
215 double designNorm, stateNorm, sumNorm;
216 double dWeight, sWeight;
217 dWeight = m_alpha;
218 sWeight = 1;
219
220 designNorm = m_grad_design.norm(NORM_2)[0];
221 stateNorm = m_grad_state.norm(NORM_2)[0];
222
223 designNorm *= dWeight;
224 stateNorm *= sWeight;
225
// Norm of the combined (total) gradient.
226 sumNorm = m_gradient.norm(NORM_2)[0];
227
228 m_log->message(2,
229 "----------------------------------------------------------\n");
230 m_log->message(2,
231 "IP_SSATaucTikhonovGNSolver Iteration %d: misfit %g; functional %g \n",
// NOTE(review): listing lines 232--233 (the message arguments and,
// presumably, an `if (m_tikhonov_adaptive) {` guard matched by the closing
// brace on line 235) are missing from this extraction.
234 m_log->message(2, "alpha %g; log(alpha) %g\n", m_alpha, m_logalpha);
235 }
// How small the combined gradient is relative to the larger of its parts.
236 double relsum = (sumNorm/std::max(designNorm,stateNorm));
237 m_log->message(2,
238 "design norm %g stateNorm %g sum %g; relative difference %g\n",
239 designNorm, stateNorm, sumNorm, relsum);
240
241 // If we have an adaptive tikhonov parameter, check if we have met
242 // this constraint first.
// NOTE(review): listing lines 243 (presumably the adaptive guard) and 246
// (presumably a keep-iterating return when the discrepancy constraint is
// not yet met) are missing from this extraction.
244 double disc_ratio = fabs((sqrt(m_val_state)/m_target_misfit) - 1.);
245 if (disc_ratio > m_tikhonov_ptol) {
247 }
248 }
249
// Absolute gradient-norm stopping test.
250 if (sumNorm < m_tikhonov_atol) {
251 return std::shared_ptr<TerminationReason>(new GenericTerminationReason(1,"TIKHONOV_ATOL"));
252 }
253
// Relative gradient-norm stopping test.
254 if (sumNorm < m_tikhonov_rtol*std::max(designNorm,stateNorm)) {
255 return std::shared_ptr<TerminationReason>(new GenericTerminationReason(1,"TIKHONOV_RTOL"));
256 }
257
// NOTE(review): listing lines 259 (presumably the max-iteration return)
// and 262 (presumably the final keep-iterating return) are missing from
// this extraction.
258 if (m_iter > m_iter_max) {
260 }
261
263}
264
// NOTE(review): listing line 265 (the evaluate_objective_and_gradient
// signature) and body lines 272, 275, 278 and 283--287 are missing from
// this extraction.
266
// Solve the forward (SSA) problem at the current design iterate m_d.
267 std::shared_ptr<TerminationReason> reason = m_ssaforward.linearize_at(*m_d);
268 if (reason->failed()) {
269 return reason;
270 }
271
// Design residual d - d0; m_d_diff is presumably seeded from m_d on the
// missing line 272.
273 m_d_diff.add(-1,m_d0);
274
// State residual u - u_obs; m_u_diff is presumably seeded from the forward
// solution on the missing line 275.
276 m_u_diff.add(-1,m_u_obs);
277
279
280 // The following computes the reduced gradient.
// Adjoint right-hand side: gradient of the state functional at the misfit.
// The transpose-linearization applications (missing lines 283--287)
// presumably fill m_grad_design, m_grad_state and m_gradient.
281 StateVec &adjointRHS = m_tmp_S1Global;
282 m_stateFunctional.gradientAt(m_u_diff,adjointRHS);
284
288
// Cache the two functional values for check_convergence()/linesearch().
289 double valDesign, valState;
290 m_designFunctional.valueAt(m_d_diff,&valDesign);
291 m_stateFunctional.valueAt(m_u_diff,&valState);
292
293 m_val_design = valDesign;
294 m_val_state = valState;
295
// Combined Tikhonov objective: alpha * J_design + J_state.
296 m_value = valDesign * m_alpha + valState;
297
298 return reason;
299}
300
301std::shared_ptr<TerminationReason> IP_SSATaucTikhonovGNSolver::linesearch() {
302 PetscErrorCode ierr;
303
304 std::shared_ptr<TerminationReason> step_reason;
305
// Objective value at the current iterate, before stepping.
306 double old_value = m_val_design * m_alpha + m_val_state;
307
308 double descent_derivative;
// NOTE(review): listing lines 310, 321 and 341 are missing from this
// extraction. Line 310 presumably copies the step m_h into m_tmp_D1Global
// (used in the dot product below); line 321 presumably saves the current
// m_d into m_tmp_D1Local (restored on line 338); line 341 is presumably a
// success return.
311
// Directional derivative of the objective along the candidate step.
312 ierr = VecDot(m_gradient.vec(), m_tmp_D1Global.vec(), &descent_derivative);
313 PISM_CHK(ierr, "VecDot");
314
// A non-negative derivative means m_h is not a descent direction.
315 if (descent_derivative >=0) {
316 printf("descent derivative: %g\n",descent_derivative);
317 return std::shared_ptr<TerminationReason>(new GenericTerminationReason(-1, "Not descent direction"));
318 }
319
// Backtracking line search with a sufficient-decrease (Armijo-style)
// condition using constant 1e-3; step length halves on each failure.
320 double alpha = 1;
322 while(true) {
323 m_d->add(alpha,m_h); // Replace with line search.
324 step_reason = this->evaluate_objective_and_gradient();
325 if (step_reason->succeeded()) {
326 if (m_value <= old_value + 1e-3*alpha*descent_derivative) {
327 break;
328 }
329 }
330 else {
// NOTE(review): "linsearch" in this runtime message is a typo for
// "linesearch" (left as-is here; changing it would change behavior).
331 printf("forward solve failed in linsearch. Shrinking.\n");
332 }
333 alpha *=.5;
// Give up once the step length underflows to a useless size.
334 if (alpha<1e-20) {
335 printf("alpha= %g; derivative = %g\n",alpha,descent_derivative);
336 return std::shared_ptr<TerminationReason>(new GenericTerminationReason(-1, "Too many step shrinks."));
337 }
// Restore the pre-step iterate before retrying with a smaller alpha.
338 m_d->copy_from(m_tmp_D1Local);
339 }
340
342}
343
344std::shared_ptr<TerminationReason> IP_SSATaucTikhonovGNSolver::solve() {
345
// The discrepancy target must be set before solving (it starts at 0 in
// the constructor).
346 if (m_target_misfit == 0) {
347 throw RuntimeError::formatted(PISM_ERROR_LOCATION, "Call set target misfit prior to calling"
348 " IP_SSATaucTikhonovGNSolver::solve.");
349 }
350
// Start from the a-priori design d0.
351 m_iter = 0;
352 m_d->copy_from(m_d0);
353
354 double dlogalpha = 0;
355
356 std::shared_ptr<TerminationReason> step_reason, reason;
357
// Initial objective/gradient evaluation (solves the forward problem).
358 step_reason = this->evaluate_objective_and_gradient();
359 if (step_reason->failed()) {
360 reason.reset(new GenericTerminationReason(-1,"Forward solve"));
361 reason->set_root_cause(step_reason);
362 return reason;
363 }
364
// Main Gauss-Newton iteration: convergence test, penalty update, GN step,
// line search, and (adaptively) a new dlogalpha for the next pass.
365 while(true) {
366
367 reason = this->check_convergence();
368 if (reason->done()) {
369 return reason;
370 }
371
// NOTE(review): listing lines 372 and 392 are missing from this
// extraction; given the closing braces on lines 375 and 400 they are
// presumably `if (m_tikhonov_adaptive) {` guards around the log-alpha
// update below and the compute_dlogalpha() call further down.
373 m_logalpha += dlogalpha;
374 m_alpha = exp(m_logalpha);
375 }
376
// Compute the Gauss-Newton step h.
377 step_reason = this->solve_linearized();
378 if (step_reason->failed()) {
379 reason.reset(new GenericTerminationReason(-1,"Gauss Newton solve"));
380 reason->set_root_cause(step_reason);
381 return reason;
382 }
383
// Step along h with backtracking.
384 step_reason = this->linesearch();
385 if (step_reason->failed()) {
386 std::shared_ptr<TerminationReason> cause = reason;
387 reason.reset(new GenericTerminationReason(-1,"Linesearch"));
388 reason->set_root_cause(step_reason);
389 return reason;
390 }
391
// Update the Tikhonov penalty step for the next iteration.
393 step_reason = this->compute_dlogalpha(&dlogalpha);
394 if (step_reason->failed()) {
395 std::shared_ptr<TerminationReason> cause = reason;
396 reason.reset(new GenericTerminationReason(-1,"Tikhonov penalty update"));
397 reason->set_root_cause(step_reason);
398 return reason;
399 }
400 }
401
402 m_iter++;
403 }
404
405 return reason;
406}
407
408std::shared_ptr<TerminationReason> IP_SSATaucTikhonovGNSolver::compute_dlogalpha(double *dlogalpha) {
409
410 PetscErrorCode ierr;
411
// NOTE(review): listing lines 413, 415--416, 425, 436--438, 442, 466--467
// and 499 are missing from this extraction; comments below hedge where a
// missing statement is presumed.
412 // Compute the right-hand side for computing dh/dalpha.
// m_d_diff_lin is presumably seeded from m_d_diff on the missing line 413;
// adding h gives the linearized design residual.
414 m_d_diff_lin.add(1,m_h);
417
418 // Solve linear equation for dh/dalpha.
419 ierr = KSPSetOperators(m_ksp,m_mat_GN,m_mat_GN);
420 PISM_CHK(ierr, "KSPSetOperators");
421
422 ierr = KSPSolve(m_ksp,m_dalpha_rhs.vec(),m_dh_dalphaGlobal.vec());
423 PISM_CHK(ierr, "KSPSolve");
424
426
// Propagate KSP failure to the caller instead of using a bad solution.
427 KSPConvergedReason ksp_reason;
428 ierr = KSPGetConvergedReason(m_ksp,&ksp_reason);
429 PISM_CHK(ierr, "KSPGetConvergedReason");
430
431 if (ksp_reason<0) {
432 return std::shared_ptr<TerminationReason>(new KSPTerminationReason(ksp_reason));
433 }
434
435 // S1Local contains T(h) + F(x) - u_obs, i.e. the linearized misfit field.
439
440 // Compute linearized discrepancy.
// disc_sq is presumably filled on the missing line 442 (listing line 443
// here is blank).
441 double disc_sq;
443
444 // There are a number of equivalent ways to compute the derivative of the
445 // linearized discrepancy with respect to alpha, some of which are cheaper
446 // than others to compute. This equivalency relies, however, on having an
447 // exact solution in the Gauss-Newton step. Since we only solve this with
448 // a soft tolerance, we lose equivalency. We attempt a cheap computation,
449 // and then do a sanity check (namely that the derivative is positive).
450 // If this fails, we compute by a harder way that inherently yields a
451 // positive number.
452
// Cheap formula: -2 alpha <dh/dalpha, d_diff_lin>_{W_D}.
453 double ddisc_sq_dalpha;
454 m_designFunctional.dot(m_dh_dalpha,m_d_diff_lin,&ddisc_sq_dalpha);
455 ddisc_sq_dalpha *= -2*m_alpha;
456
457 if (ddisc_sq_dalpha <= 0) {
458 // Try harder.
459
460 m_log->message(3,
461 "Adaptive Tikhonov sanity check failed (dh/dalpha= %g <= 0)."
462 " Tighten inv_gn_ksp_rtol?\n",
463 ddisc_sq_dalpha);
464
465 // S2Local contains T(dh/dalpha)
// Fallback formula built from two inner products; each term is a squared
// norm, so the result is non-negative by construction.
468
469 double ddisc_sq_dalpha_a;
470 m_stateFunctional.dot(m_tmp_S2Local,m_tmp_S2Local,&ddisc_sq_dalpha_a);
471 double ddisc_sq_dalpha_b;
472 m_designFunctional.dot(m_dh_dalpha,m_dh_dalpha,&ddisc_sq_dalpha_b);
473 ddisc_sq_dalpha = 2*m_alpha*(ddisc_sq_dalpha_a+m_alpha*ddisc_sq_dalpha_b);
474
475 m_log->message(3,
476 "Adaptive Tikhonov sanity check recovery attempt: dh/dalpha= %g. \n",
477 ddisc_sq_dalpha);
478
479 // This is yet another alternative formula.
480 // m_stateFunctional.dot(m_tmp_S1Local,m_tmp_S2Local,&ddisc_sq_dalpha);
481 // ddisc_sq_dalpha *= 2;
482 }
483
484 // Newton's method formula.
// One Newton step in log(alpha) toward matching the target discrepancy.
485 *dlogalpha = (m_target_misfit*m_target_misfit-disc_sq)/(ddisc_sq_dalpha*m_alpha);
486
487 // It's easy to take steps that are too big when we are far from the solution.
488 // So we limit the step size.
489 double stepmax = 3;
490 if (fabs(*dlogalpha)> stepmax) {
491 double sgn = *dlogalpha > 0 ? 1 : -1;
492 *dlogalpha = stepmax*sgn;
493 }
494
// Damp decreases of alpha (decreasing regularization is the risky
// direction).
495 if (*dlogalpha<0) {
496 *dlogalpha*=.5;
497 }
498
// NOTE(review): the return statement (listing line 499) is missing from
// this extraction; presumably a success TerminationReason.
500}
501
502} // end of namespace inverse
503} // end of namespace pism
static std::shared_ptr< TerminationReason > max_iter()
static std::shared_ptr< TerminationReason > keep_iterating()
static std::shared_ptr< TerminationReason > success()
static RuntimeError formatted(const ErrorLocation &location, const char format[],...) __attribute__((format(printf, 2, 3)))
build a RuntimeError with a formatted message
T * rawptr()
Definition Wrapper.hh:34
void copy_from(const Array2D< T > &source)
Definition Array2D.hh:101
void add(double alpha, const Array2D< T > &x)
Definition Array2D.hh:93
petsc::Vec & vec() const
Definition Array.cc:313
void scale(double alpha)
Result: v <- v * alpha. Calls VecScale.
Definition Array.cc:227
std::shared_ptr< const Grid > grid() const
Definition Array.cc:134
void set(double c)
Result: v[j] <- c for all j.
Definition Array.cc:659
std::shared_ptr< petsc::DM > dm() const
Definition Array.cc:327
std::vector< double > norm(int n) const
Computes the norm of all the components of an Array.
Definition Array.cc:699
void update_ghosts()
Updates ghost points.
Definition Array.cc:645
Abstract base class for IPFunctionals arising from an inner product.
virtual std::shared_ptr< array::Vector > solution()
Returns the last solution of the SSA as computed by linearize_at.
virtual void apply_linearization(array::Scalar &dzeta, array::Vector &du)
Applies the linearization of the forward map (i.e. the reduced gradient described in the class-level...
virtual std::shared_ptr< TerminationReason > linearize_at(array::Scalar &zeta)
Sets the current value of the design variable and solves the SSA to find the associated .
virtual void apply_linearization_transpose(array::Vector &du, array::Scalar &dzeta)
Applies the transpose of the linearization of the forward map (i.e. the transpose of the reduced grad...
Implements the forward problem of the map taking to the corresponding solution of the SSA.
virtual std::shared_ptr< TerminationReason > linesearch()
virtual void apply_GN(array::Scalar &h, array::Scalar &out)
virtual void evaluateGNFunctional(DesignVec &h, double *value)
virtual std::shared_ptr< TerminationReason > init()
virtual std::shared_ptr< TerminationReason > evaluate_objective_and_gradient()
IPInnerProductFunctional< StateVec > & m_stateFunctional
IPInnerProductFunctional< DesignVec > & m_designFunctional
virtual std::shared_ptr< TerminationReason > compute_dlogalpha(double *dalpha)
virtual std::shared_ptr< TerminationReason > solve()
virtual std::shared_ptr< TerminationReason > check_convergence()
virtual std::shared_ptr< TerminationReason > solve_linearized()
IP_SSATaucTikhonovGNSolver(IP_SSATaucForwardProblem &ssaforward, DesignVec &d0, StateVec &u_obs, double eta, IPInnerProductFunctional< DesignVec > &designFunctional, IPInnerProductFunctional< StateVec > &stateFunctional)
#define PISM_CHK(errcode, name)
#define PISM_ERROR_LOCATION
std::string printf(const char *format,...)