PISM, A Parallel Ice Sheet Model 2.3.0-79cae578d committed by Constantine Khrulev on 2026-03-22
Loading...
Searching...
No Matches
SSAFD_SNES.cc
Go to the documentation of this file.
1/* Copyright (C) 2024, 2025 PISM Authors
2 *
3 * This file is part of PISM.
4 *
5 * PISM is free software; you can redistribute it and/or modify it under the
6 * terms of the GNU General Public License as published by the Free Software
7 * Foundation; either version 3 of the License, or (at your option) any later
8 * version.
9 *
10 * PISM is distributed in the hope that it will be useful, but WITHOUT ANY
11 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12 * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
13 * details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with PISM; if not, write to the Free Software
17 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18 */
19#include <algorithm> // std::max()
20
21#include "pism/stressbalance/ssa/SSAFD_SNES.hh"
22#include "pism/stressbalance/StressBalance.hh" // Inputs
23#include "pism/util/petscwrappers/Vec.hh"
24#include "pism/util/Logger.hh"
25
26namespace pism {
27namespace stressbalance {
28
29PetscErrorCode SSAFDSNESConvergenceTest(SNES snes, PetscInt it, PetscReal xnorm, PetscReal gnorm,
30 PetscReal f, SNESConvergedReason *reason, void *ctx) {
31 PetscErrorCode ierr;
32
33 SSAFD_SNES *solver = reinterpret_cast<SSAFD_SNES *>(ctx);
34 double tolerance = solver->tolerance();
35
36 ierr = SNESConvergedDefault(snes, it, xnorm, gnorm, f, reason, ctx); CHKERRQ(ierr);
37 if (*reason >= 0 and tolerance > 0) {
38 // converged or iterating
39 Vec residual;
40 ierr = SNESGetFunction(snes, &residual, NULL, NULL);
41 CHKERRQ(ierr);
42
43 PetscReal norm;
44 ierr = VecNorm(residual, NORM_INFINITY, &norm);
45 CHKERRQ(ierr);
46
47 if (norm <= tolerance) {
48 *reason = SNES_CONVERGED_FNORM_ABS;
49 }
50 }
51
52 return 0;
53}
54
55double SSAFD_SNES::tolerance() const {
56 return m_config->get_number("stress_balance.ssa.fd.absolute_tolerance");
57}
58
// Constructor: sets up the DMDA, the SNES solver, its local residual/Jacobian callbacks
// and its convergence test. Also creates the "_ssa_residual" field (m_residual).
59SSAFD_SNES::SSAFD_SNES(std::shared_ptr<const Grid> grid, bool regional_mode)
60 : SSAFDBase(grid, regional_mode), m_residual(grid, "_ssa_residual") {
61
62 PetscErrorCode ierr;
63
// get_dm(2, stencil_width): presumably 2 degrees of freedom per grid point (the two
// components of the Vector2d velocity used by the callbacks) -- confirm against Grid::get_dm.
64 int stencil_width=2;
65 m_DA = m_grid->get_dm(2, stencil_width);
66
67 // ierr = DMCreateGlobalVector(*m_DA, m_X.rawptr());
68 // PISM_CHK(ierr, "DMCreateGlobalVector");
69
70 ierr = SNESCreate(m_grid->com, m_snes.rawptr());
71 PISM_CHK(ierr, "SNESCreate");
72
73 // Set the SNES callbacks (function_callback and jacobian_callback, which call compute_residual() and compute_jacobian())
// NOTE(review): source lines 74-75 are missing from this listing. They presumably initialize
// the other CallbackData fields the callbacks read (`da` and `solver`) -- confirm.
// `inputs` stays null until solve() points it at the current Inputs.
76 m_callback_data.inputs = nullptr;
77
78 ierr = DMDASNESSetFunctionLocal(*m_DA, INSERT_VALUES,
79#if PETSC_VERSION_LT(3,21,0)
80 (DMDASNESFunction)SSAFD_SNES::function_callback,
81#else
82 (DMDASNESFunctionFn*)SSAFD_SNES::function_callback,
83#endif
// NOTE(review): source line 84 (the remaining argument(s) and closing `);` of the call
// above) is missing from this listing -- confirm against the original file.
85 PISM_CHK(ierr, "DMDASNESSetFunctionLocal");
86
87 ierr = DMDASNESSetJacobianLocal(*m_DA,
88#if PETSC_VERSION_LT(3,21,0)
89 (DMDASNESJacobian)SSAFD_SNES::jacobian_callback,
90#else
91 (DMDASNESJacobianFn*)SSAFD_SNES::jacobian_callback,
92#endif
// NOTE(review): source line 93 (the remaining argument(s) and closing `);` of the call
// above) is missing from this listing -- confirm against the original file.
94 PISM_CHK(ierr, "DMDASNESSetJacobianLocal");
95
96 // ierr = DMSetMatType(*m_DA, "baij");
97 // PISM_CHK(ierr, "DMSetMatType");
98
// The callbacks receive &m_callback_data as their CallbackData* argument.
99 ierr = DMSetApplicationContext(*m_DA, &m_callback_data);
100 PISM_CHK(ierr, "DMSetApplicationContext");
101
// Command-line options for this solver use the "ssafd_" prefix.
102 ierr = SNESSetOptionsPrefix(m_snes, "ssafd_");
103 PISM_CHK(ierr, "SNESSetOptionsPrefix");
104
105 ierr = SNESSetDM(m_snes, *m_DA);
106 PISM_CHK(ierr, "SNESSetDM");
107
// `this` becomes the ctx that SSAFDSNESConvergenceTest casts back to SSAFD_SNES.
108 ierr = SNESSetConvergenceTest(m_snes, SSAFDSNESConvergenceTest, this, NULL);
109 PISM_CHK(ierr, "SNESSetConvergenceTest");
110
// Absolute, relative and step tolerances are set to 0. Presumably this leaves the
// convergence decision to the absolute infinity-norm check in SSAFDSNESConvergenceTest.
// At most 500 nonlinear iterations are allowed.
111 ierr = SNESSetTolerances(m_snes, 0.0, 0.0, 0.0, 500, -1);
112 PISM_CHK(ierr, "SNESSetTolerances");
113
// Called last so that "ssafd_"-prefixed options can override the settings above.
114 ierr = SNESSetFromOptions(m_snes);
115 PISM_CHK(ierr, "SNESSetFromOptions");
116}
117
// Solves the SSA with the SNES solver.
// m_velocity_global is passed to SNESSolve as both the initial guess and the solution.
// The solve is reported as failed if SNES returns a negative (diverged) reason.
// Logs the number of nonlinear iterations and the average number of linear iterations per
// nonlinear iteration.
118void SSAFD_SNES::solve(const Inputs &inputs) {
// The callbacks read the inputs through m_callback_data during SNESSolve.
// NOTE(review): if anything below throws (PISM_CHK on a PETSc error presumably does, as does
// the non-convergence error), the reset to nullptr near the end is skipped. In that case
// m_callback_data.inputs keeps pointing at `inputs`. An RAII guard would make the reset
// unconditional.
119 m_callback_data.inputs = &inputs;
120 initialize_iterations(inputs);
121 {
122 PetscErrorCode ierr;
123
124 // Solve:
125 // ierr = SNESSolve(m_snes, NULL, m_X);
126 ierr = SNESSolve(m_snes, NULL, m_velocity_global.vec());
127 PISM_CHK(ierr, "SNESSolve");
128
129 // See if it worked.
130 SNESConvergedReason reason;
131 ierr = SNESGetConvergedReason(m_snes, &reason);
132 PISM_CHK(ierr, "SNESGetConvergedReason");
133 if (reason < 0) {
// NOTE(review): source line 134 is missing from this listing. It is presumably
// `throw RuntimeError::formatted(PISM_ERROR_LOCATION,` (both names appear in this page's
// index) -- confirm. Indexing SNESConvergedReasons with a negative reason relies on PETSc
// defining that array so that diverged reasons are valid indices.
135 "SSAFD_SNES solve failed to converge (SNES reason %s)",
136 SNESConvergedReasons[reason]);
137 }
138
139 PetscInt snes_iterations = 0;
140 ierr = SNESGetIterationNumber(m_snes, &snes_iterations);
141 PISM_CHK(ierr, "SNESGetIterationNumber");
142
143 PetscInt ksp_iterations = 0;
144 ierr = SNESGetLinearSolveIterations(m_snes, &ksp_iterations);
145 PISM_CHK(ierr, "SNESGetLinearSolveIterations");
146
// The second number is the integer average of linear iterations per nonlinear iteration.
// std::max(..., 1) avoids division by zero when no nonlinear iterations were taken.
147 m_log->message(1, "SSA: %d*%d its, %s\n", (int)snes_iterations,
148 (int)(ksp_iterations / std::max((int)snes_iterations, 1)),
149 SNESConvergedReasons[reason]);
150 }
151 m_callback_data.inputs = nullptr;
152
153 // copy from m_velocity_global to provide m_velocity with ghosts:
// NOTE(review): source lines 154 and 156 are missing from this listing. Line 154 presumably
// performs the copy described above (e.g. via copy_from) -- confirm.
155
157}
158
159
// Local residual callback registered with DMDASNESSetFunctionLocal() in the constructor.
// Forwards to compute_residual(), writing the residual at `velocity` into `result`.
// It dereferences data->inputs, which solve() sets before calling SNESSolve.
// Any C++ exception is caught and reported to PETSc as an error instead, so that exceptions
// do not propagate through PETSc's C code.
160PetscErrorCode SSAFD_SNES::function_callback(DMDALocalInfo * /*unused*/,
161 Vector2d const *const *velocity, Vector2d **result,
162 CallbackData *data) {
163 try {
164 data->solver->compute_residual(*data->inputs, velocity, result);
165 } catch (...) {
166 MPI_Comm com = MPI_COMM_SELF;
167 PetscErrorCode ierr = PetscObjectGetComm((PetscObject)data->da, &com);
168 CHKERRQ(ierr);
// NOTE(review): source line 169 is missing from this listing. It presumably calls
// handle_fatal_errors(com) (listed in this page's index) to report the caught exception -- confirm.
170 SETERRQ(com, 1, "A PISM callback failed");
171 }
172 return 0;
173}
174
// Fills the matrix J for the given velocity. It is called from jacobian_callback().
175void SSAFD_SNES::compute_jacobian(const Inputs &inputs, Vector2d const *const *const velocity,
176 Mat J) {
// NOTE(review): the body (source lines 177-178) is missing from this listing. This page's
// index lists SSAFDBase::fd_operator(..., Mat *A, Vector2d **Ax), which presumably assembles J
// here -- confirm against the original file.
179}
180
// Local Jacobian callback registered with DMDASNESSetJacobianLocal() in the constructor.
// Forwards to compute_jacobian() and assembles only the second matrix argument, J. In PETSc's
// convention that is the matrix used to build the preconditioner; the first (A) is ignored.
// It dereferences data->inputs, which solve() sets before calling SNESSolve.
// Any C++ exception is caught and reported to PETSc as an error instead.
181PetscErrorCode SSAFD_SNES::jacobian_callback(DMDALocalInfo * /*unused*/,
182 Vector2d const *const *const velocity, Mat /* A */,
183 Mat J, CallbackData *data) {
184 try {
185 data->solver->compute_jacobian(*data->inputs, velocity, J);
186 } catch (...) {
187 MPI_Comm com = MPI_COMM_SELF;
188 PetscErrorCode ierr = PetscObjectGetComm((PetscObject)data->da, &com);
189 CHKERRQ(ierr);
// NOTE(review): source line 190 is missing from this listing. It presumably calls
// handle_fatal_errors(com), as in function_callback -- confirm.
191 SETERRQ(com, 1, "A PISM callback failed");
192 }
193 return 0;
194}
195
// Returns the stored residual field m_residual (also exposed as the "ssa_residual" diagnostic).
// NOTE(review): the function header (source line 196) is missing from this listing. This page's
// index gives it as `const array::Vector & residual() const`.
197 return m_residual;
198}
199
200//! @brief Computes the magnitude of the SSAFD solver's residual
201//! (diagnostic "ssa_residual_mag", units Pa).
202class SSAFD_residual_mag : public Diag<SSAFD_SNES> {
203public:
// NOTE(review): the constructor header (source line 204) is missing from this listing. It
// presumably takes the SSAFD_SNES model and forwards it to Diag<SSAFD_SNES> -- confirm.
205
206 // set metadata:
207 m_vars = { { m_sys, "ssa_residual_mag", *m_grid } };
208
209 m_vars[0].long_name("magnitude of the SSAFD solver's residual").units("Pa");
210 }
211
212protected:
// Allocates a scalar field named "ssa_residual_mag" and attaches the metadata set in the
// constructor. Then fills the field with the magnitude of model->residual().
213 virtual std::shared_ptr<array::Array> compute_impl() const {
214 auto result = allocate<array::Scalar>("ssa_residual_mag");
215 result->metadata(0) = m_vars[0];
216
217 compute_magnitude(model->residual(), *result);
218
219 return result;
220 }
221};
222
// Registers the solver's residual diagnostics:
//   - "ssa_residual": the m_residual field itself;
//   - "ssa_residual_mag": its magnitude (SSAFD_residual_mag).
// NOTE(review): source lines 223-224 are missing from this listing. They hold the function
// header and, presumably, the declaration of `result`; whether it starts from the base class's
// diagnostics cannot be told from here -- confirm.
225
226 result["ssa_residual"] = Diagnostic::wrap(m_residual);
227 result["ssa_residual_mag"] = Diagnostic::Ptr(new SSAFD_residual_mag(this));
228
229 return result;
230}
231
232
233} // namespace stressbalance
234} // namespace pism
std::shared_ptr< const Config > m_config
configuration database used by this component
Definition Component.hh:160
const std::shared_ptr< const Grid > m_grid
grid used by this component
Definition Component.hh:158
std::shared_ptr< const Logger > m_log
logger (for easy access)
Definition Component.hh:164
const SSAFD_SNES * model
A template derived from Diagnostic, adding a "Model".
std::vector< VariableMetadata > m_vars
metadata corresponding to NetCDF variables
static Ptr wrap(const T &input)
const units::System::Ptr m_sys
the unit system
std::shared_ptr< Diagnostic > Ptr
Definition Diagnostic.hh:67
std::shared_ptr< const Grid > m_grid
the grid
static RuntimeError formatted(const ErrorLocation &location, const char format[],...) __attribute__((format(printf
build a RuntimeError with a formatted message
This class represents a 2D vector field (such as ice velocity) at a certain grid point.
Definition Vector2d.hh:29
T * rawptr()
Definition Wrapper.hh:34
void copy_from(const Array2D< T > &source)
Definition Array2D.hh:101
petsc::Vec & vec() const
Definition Array.cc:313
const array::Scalar * basal_yield_stress
const array::Scalar * bc_mask
void fd_operator(const Geometry &geometry, const array::Scalar *bc_mask, double bc_scaling, const array::Scalar &basal_yield_stress, IceBasalResistancePlasticLaw *basal_sliding_law, const pism::Vector2d *const *velocity, const array::Staggered1 &nuH, const array::CellType1 &cell_type, Mat *A, Vector2d **Ax) const
Assemble the left-hand side matrix for the KSP-based, Picard iteration, and finite difference impleme...
Definition SSAFDBase.cc:549
void initialize_iterations(const Inputs &inputs)
array::Staggered1 m_nuH
viscosity times thickness
Definition SSAFDBase.hh:119
const double m_bc_scaling
scaling used for diagonal matrix elements at Dirichlet BC locations
Definition SSAFDBase.hh:130
void compute_residual(const Inputs &inputs, const array::Vector2 &velocity, array::Vector &result)
DiagnosticList spatial_diagnostics_impl() const
const array::Vector & residual() const
void solve(const Inputs &inputs)
std::shared_ptr< petsc::DM > m_DA
Definition SSAFD_SNES.hh:48
SSAFD_SNES(std::shared_ptr< const Grid > grid, bool regional_mode)
Definition SSAFD_SNES.cc:59
static PetscErrorCode jacobian_callback(DMDALocalInfo *info, Vector2d const *const *velocity, Mat A, Mat J, CallbackData *data)
DiagnosticList spatial_diagnostics_impl() const
static PetscErrorCode function_callback(DMDALocalInfo *info, Vector2d const *const *velocity, Vector2d **result, CallbackData *)
void compute_jacobian(const Inputs &inputs, Vector2d const *const *velocity, Mat J)
array::Vector m_residual
residual (diagnostic)
Definition SSAFD_SNES.hh:45
virtual std::shared_ptr< array::Array > compute_impl() const
Computes the magnitude of the SSAFD solver's residual (diagnostically).
array::Vector m_velocity_global
Definition SSA.hh:122
const array::Vector1 & velocity() const
Get the thickness-advective 2D velocity.
IceBasalResistancePlasticLaw * m_basal_sliding_law
#define PISM_CHK(errcode, name)
#define PISM_ERROR_LOCATION
PetscErrorCode SSAFDSNESConvergenceTest(SNES snes, PetscInt it, PetscReal xnorm, PetscReal gnorm, PetscReal f, SNESConvergedReason *reason, void *ctx)
Definition SSAFD_SNES.cc:29
std::map< std::string, Diagnostic::Ptr > DiagnosticList
void handle_fatal_errors(MPI_Comm com)