-
Notifications
You must be signed in to change notification settings - Fork 101
Avoid reallocation for self residual norm calculation #1898
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| // SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors | ||
| // SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors | ||
| // | ||
| // SPDX-License-Identifier: BSD-3-Clause | ||
|
|
||
|
|
@@ -86,6 +86,8 @@ class ResidualNormBase | |
| std::shared_ptr<const Vector> neg_one_{}; | ||
| // workspace for reduction | ||
| mutable gko::array<char> reduction_tmp_; | ||
| // temporary rhs for residual computation | ||
| mutable std::shared_ptr<LinOp> rhs_{}; | ||
|
Comment on lines
88
to
+90
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we should avoid adding more |
||
| }; | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| // SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors | ||
| // SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors | ||
| // | ||
| // SPDX-License-Identifier: BSD-3-Clause | ||
|
|
||
|
|
@@ -13,7 +13,19 @@ | |
| #include "core/test/utils.hpp" | ||
|
|
||
|
|
||
| namespace { | ||
| /** | ||
| * Test-only logger that counts the allocation-completed events it receives | ||
| * and prints the size of each one. | ||
| * | ||
| * `count` is public so a test can read it directly after running the code | ||
| * under test. | ||
| */ | ||
| class AllocationLogger : public gko::log::Logger { | ||
| public: | ||
| // Number of allocation-completed events seen so far. It is mutable because | ||
| // the logging hook below is a const member function. | ||
| mutable int count = 0; | ||
|
|
||
| protected: | ||
| // Called once per completed allocation: prints the requested size and | ||
| // bumps the counter. The executor and the location are not needed here, so | ||
| // their parameter names are commented out. | ||
| void on_allocation_completed(const gko::Executor* /*exec*/, | ||
| const gko::size_type& num_bytes, | ||
| const gko::uintptr& /*location*/) const override | ||
| { | ||
| // '\n' instead of std::endl: no need to flush the stream per event. | ||
| std::cout << num_bytes << '\n'; | ||
| ++count; | ||
| } | ||
| }; | ||
|
|
||
|
|
||
| template <typename T> | ||
|
|
@@ -42,7 +54,7 @@ class ResidualNorm : public ::testing::Test { | |
| std::unique_ptr<typename gko::stop::ResidualNorm<T>::Factory> rhs_factory_; | ||
| std::unique_ptr<typename gko::stop::ResidualNorm<T>::Factory> rel_factory_; | ||
| std::unique_ptr<typename gko::stop::ResidualNorm<T>::Factory> abs_factory_; | ||
| std::shared_ptr<const gko::Executor> exec_; | ||
| std::shared_ptr<gko::Executor> exec_; | ||
| }; | ||
|
|
||
| TYPED_TEST_SUITE(ResidualNorm, gko::test::ValueTypes, TypenameNameGenerator); | ||
|
|
@@ -417,6 +429,111 @@ TYPED_TEST(ResidualNorm, SelfCalculatesAndWaitsTillResidualGoal) | |
| } | ||
|
|
||
|
|
||
| // Checks that a criterion which is updated with the solution only does not | ||
| // allocate on the executor once it has already been checked before. | ||
| // | ||
| // Each section below exercises one factory (rhs_factory_, rel_factory_, | ||
| // abs_factory_) with the same pattern: two checks run without a logger, | ||
| // then an AllocationLogger is attached for the third check only and its | ||
| // counter must still be zero afterwards. Allocations made during the first | ||
| // two checks are therefore deliberately not counted. | ||
| TYPED_TEST(ResidualNorm, SelfCalculatesWithoutReallocation) | ||
| { | ||
| using Mtx = typename TestFixture::Mtx; | ||
| using NormVector = typename TestFixture::NormVector; | ||
| using T = TypeParam; | ||
| auto initial_res = gko::initialize<Mtx>({100.0}, this->exec_); | ||
| auto system_mtx = share(gko::initialize<Mtx>({1.0}, this->exec_)); | ||
|
|
||
| T rhs_val = 10.0; | ||
| std::shared_ptr<gko::LinOp> rhs = | ||
| gko::initialize<Mtx>({rhs_val}, this->exec_); | ||
| auto rhs_criterion = this->rhs_factory_->generate(system_mtx, rhs, nullptr, | ||
| initial_res.get()); | ||
| auto rel_criterion = this->rel_factory_->generate(system_mtx, rhs, nullptr, | ||
| initial_res.get()); | ||
| auto abs_criterion = this->abs_factory_->generate(system_mtx, rhs, nullptr, | ||
| initial_res.get()); | ||
| { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Either these blocks need a |
||
| // Label failures of this section; the three sections share the same | ||
| // assertions and would otherwise differ only by line number. | ||
| SCOPED_TRACE("criterion generated from rhs_factory_"); | ||
| auto solution = gko::initialize<Mtx>({rhs_val - T{10.0}}, this->exec_); | ||
| auto rhs_norm = gko::initialize<NormVector>({100.0}, this->exec_); | ||
| gko::as<Mtx>(rhs)->compute_norm2(rhs_norm); | ||
| constexpr gko::uint8 RelativeStoppingId{1}; | ||
| bool one_changed{}; | ||
| gko::array<gko::stopping_status> stop_status(this->exec_, 1); | ||
| stop_status.get_data()[0].reset(); | ||
| auto logger = std::make_shared<AllocationLogger>(); | ||
|
|
||
| ASSERT_FALSE(rhs_criterion->update().solution(solution).check( | ||
| RelativeStoppingId, true, &stop_status, &one_changed)); | ||
|
|
||
| solution->at(0) = rhs_val - r<T>::value * T{1.1} * rhs_norm->at(0); | ||
| ASSERT_FALSE(rhs_criterion->update().solution(solution).check( | ||
| RelativeStoppingId, true, &stop_status, &one_changed)); | ||
| ASSERT_EQ(stop_status.get_data()[0].has_converged(), false); | ||
| ASSERT_EQ(one_changed, false); | ||
|
|
||
| solution->at(0) = rhs_val - r<T>::value * T{0.5} * rhs_norm->at(0); | ||
| // Only the check below is observed by the allocation logger. | ||
| this->exec_->add_logger(logger); | ||
| ASSERT_TRUE(rhs_criterion->update().solution(solution).check( | ||
| RelativeStoppingId, true, &stop_status, &one_changed)); | ||
| ASSERT_EQ(stop_status.get_data()[0].has_converged(), true); | ||
| ASSERT_EQ(one_changed, true); | ||
| ASSERT_EQ(logger->count, 0); | ||
| this->exec_->remove_logger(logger); | ||
| } | ||
| { | ||
| SCOPED_TRACE("criterion generated from rel_factory_"); | ||
| T initial_norm = 100.0; | ||
| auto solution = | ||
| gko::initialize<Mtx>({rhs_val - initial_norm}, this->exec_); | ||
| constexpr gko::uint8 RelativeStoppingId{1}; | ||
| bool one_changed{}; | ||
| gko::array<gko::stopping_status> stop_status(this->exec_, 1); | ||
| stop_status.get_data()[0].reset(); | ||
| auto logger = std::make_shared<AllocationLogger>(); | ||
|
|
||
| ASSERT_FALSE(rel_criterion->update().solution(solution).check( | ||
| RelativeStoppingId, true, &stop_status, &one_changed)); | ||
|
|
||
| solution->at(0) = rhs_val - r<T>::value * T{1.1} * initial_norm; | ||
| ASSERT_FALSE(rel_criterion->update().solution(solution).check( | ||
| RelativeStoppingId, true, &stop_status, &one_changed)); | ||
| ASSERT_EQ(stop_status.get_data()[0].has_converged(), false); | ||
| ASSERT_EQ(one_changed, false); | ||
|
|
||
| solution->at(0) = rhs_val - r<T>::value * T{0.5} * initial_norm; | ||
| // Only the check below is observed by the allocation logger. | ||
| this->exec_->add_logger(logger); | ||
| ASSERT_TRUE(rel_criterion->update().solution(solution).check( | ||
| RelativeStoppingId, true, &stop_status, &one_changed)); | ||
| ASSERT_EQ(stop_status.get_data()[0].has_converged(), true); | ||
| ASSERT_EQ(one_changed, true); | ||
| ASSERT_EQ(logger->count, 0); | ||
| this->exec_->remove_logger(logger); | ||
| } | ||
| { | ||
| SCOPED_TRACE("criterion generated from abs_factory_"); | ||
| auto solution = gko::initialize<Mtx>({rhs_val - T{100.0}}, this->exec_); | ||
| constexpr gko::uint8 RelativeStoppingId{1}; | ||
| bool one_changed{}; | ||
| gko::array<gko::stopping_status> stop_status(this->exec_, 1); | ||
| stop_status.get_data()[0].reset(); | ||
| auto logger = std::make_shared<AllocationLogger>(); | ||
|
|
||
| ASSERT_FALSE(abs_criterion->update().solution(solution).check( | ||
| RelativeStoppingId, true, &stop_status, &one_changed)); | ||
|
|
||
| // TODO FIXME: NVHPC calculates different result of rhs - r*1.2 from | ||
| // rhs - tmp = rhs - (r * 1.2). https://godbolt.org/z/GrGE9PE67 | ||
| solution->at(0) = rhs_val - r<T>::value * T{1.4}; | ||
| ASSERT_FALSE(abs_criterion->update().solution(solution).check( | ||
| RelativeStoppingId, true, &stop_status, &one_changed)); | ||
| ASSERT_EQ(stop_status.get_data()[0].has_converged(), false); | ||
| ASSERT_EQ(one_changed, false); | ||
|
|
||
| solution->at(0) = rhs_val - r<T>::value * T{0.5}; | ||
| // Only the check below is observed by the allocation logger. | ||
| this->exec_->add_logger(logger); | ||
| ASSERT_TRUE(abs_criterion->update().solution(solution).check( | ||
| RelativeStoppingId, true, &stop_status, &one_changed)); | ||
| ASSERT_EQ(stop_status.get_data()[0].has_converged(), true); | ||
| ASSERT_EQ(one_changed, true); | ||
| ASSERT_EQ(logger->count, 0); | ||
| this->exec_->remove_logger(logger); | ||
| } | ||
| } | ||
|
|
||
|
|
||
| TYPED_TEST(ResidualNorm, WaitsTillResidualGoalMultipleRHS) | ||
| { | ||
| using Mtx = typename TestFixture::Mtx; | ||
|
|
@@ -1079,6 +1196,3 @@ TYPED_TEST(ResidualNormWithAbsolute, WaitsTillResidualGoalMultipleRHS) | |
| ASSERT_EQ(stop_status.get_data()[1].has_converged(), true); | ||
| ASSERT_EQ(one_changed, true); | ||
| } | ||
|
|
||
|
|
||
| } // namespace | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't this the same as