Skip to content

Commit a0d77bd

Browse files
committed
[core] allow parsing type=matrix::Identity
1 parent d2e62da commit a0d77bd

File tree

7 files changed

+96
-8
lines changed

7 files changed

+96
-8
lines changed

core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ target_sources(
4848
components/range_minimum_query.cpp
4949
config/config.cpp
5050
config/config_helper.cpp
51+
config/matrix.cpp
5152
config/property_tree.cpp
5253
config/stop_config.cpp
5354
config/type_descriptor.cpp

core/config/config_helper.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ enum class LinOpFactoryType : int {
7171
Sor,
7272
Multigrid,
7373
Pgm,
74-
Schwarz
74+
Schwarz,
75+
Identity
7576
};
7677

7778

core/config/matrix.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#include <ginkgo/core/base/exception_helpers.hpp>
6+
#include <ginkgo/core/config/config.hpp>
7+
#include <ginkgo/core/config/registry.hpp>
8+
#include <ginkgo/core/matrix/identity.hpp>
9+
10+
#include "core/config/config_helper.hpp"
11+
#include "core/config/dispatch.hpp"
12+
#include "core/config/parse_macro.hpp"
13+
14+
15+
namespace gko {
16+
namespace config {
17+
18+
19+
template <typename ValueType>
20+
struct IdentityParser {
21+
static typename matrix::IdentityFactory<ValueType>::parameters_type parse(
22+
const pnode& config, const registry& context,
23+
const type_descriptor& td_for_child)
24+
{
25+
return {};
26+
}
27+
};
28+
29+
GKO_PARSE_VALUE_TYPE(Identity, IdentityParser);
30+
31+
32+
} // namespace config
33+
} // namespace gko

core/config/registry.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,9 @@ configuration_map generate_config_map()
5353
{"solver::Multigrid", parse<LinOpFactoryType::Multigrid>},
5454
{"multigrid::Pgm", parse<LinOpFactoryType::Pgm>},
5555
#if GINKGO_BUILD_MPI
56-
{
57-
"preconditioner::Schwarz", parse<LinOpFactoryType::Schwarz>
58-
}
56+
{"preconditioner::Schwarz", parse<LinOpFactoryType::Schwarz>},
5957
#endif
58+
{"matrix::Identity", parse<LinOpFactoryType::Identity>},
6059
};
6160
}
6261

core/test/config/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
ginkgo_create_test(config)
22
ginkgo_create_test(factorization)
3+
ginkgo_create_test(matrix)
34
ginkgo_create_test(multigrid)
45
ginkgo_create_test(preconditioner)
56
ginkgo_create_test(property_tree)

core/test/config/matrix.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#include <typeinfo>
6+
7+
#include <gtest/gtest.h>
8+
9+
#include <ginkgo/core/base/exception.hpp>
10+
#include <ginkgo/core/base/executor.hpp>
11+
#include <ginkgo/core/config/config.hpp>
12+
#include <ginkgo/core/matrix/identity.hpp>
13+
14+
#include "cmake-build-debug/_deps/googletest-src/googletest/include/gtest/gtest-typed-test.h"
15+
#include "core/config/config_helper.hpp"
16+
#include "core/config/registry_accessor.hpp"
17+
#include "core/test/utils.hpp"
18+
19+
20+
using namespace gko::config;
21+
22+
23+
template <typename ValueType>
24+
class Identity : public ::testing::Test {
25+
protected:
26+
using value_type = ValueType;
27+
28+
std::shared_ptr<const gko::Executor> exec =
29+
gko::ReferenceExecutor::create();
30+
std::shared_ptr<const gko::matrix::Identity<value_type>> ans =
31+
gko::matrix::Identity<value_type>::create(exec, 4u);
32+
};
33+
34+
TYPED_TEST_SUITE(Identity, gko::test::ValueTypes, TypenameNameGenerator);
35+
36+
37+
TYPED_TEST(Identity, CanParse)
38+
{
39+
using value_type = typename TestFixture::value_type;
40+
auto config = pnode({{"type", pnode("matrix::Identity")}});
41+
42+
auto res = parse(config, {}, make_type_descriptor<value_type>())
43+
.on(this->exec)
44+
->generate(this->ans);
45+
46+
ASSERT_TRUE(dynamic_cast<gko::matrix::Identity<value_type>*>(res.get()));
47+
GKO_ASSERT_EQUAL_DIMENSIONS(res, this->ans);
48+
}

include/ginkgo/core/matrix/identity.hpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -88,11 +88,12 @@ class Identity : public EnableLinOp<Identity<ValueType>>, public Transposable {
8888
template <typename ValueType = default_precision>
8989
class IdentityFactory
9090
: public EnablePolymorphicObject<IdentityFactory<ValueType>, LinOpFactory> {
91-
friend class EnablePolymorphicObject<IdentityFactory, LinOpFactory>;
92-
9391
public:
9492
using value_type = ValueType;
9593

94+
struct parameters_type
95+
: enable_parameters_type<parameters_type, IdentityFactory> {};
96+
9697
/**
9798
* Creates a new Identity factory.
9899
*
@@ -108,10 +109,14 @@ class IdentityFactory
108109
}
109110

110111
protected:
112+
friend class EnablePolymorphicObject<IdentityFactory, LinOpFactory>;
113+
friend class enable_parameters_type<parameters_type, IdentityFactory>;
114+
111115
std::unique_ptr<LinOp> generate_impl(
112116
std::shared_ptr<const LinOp> base) const override;
113117

114-
IdentityFactory(std::shared_ptr<const Executor> exec)
118+
explicit IdentityFactory(std::shared_ptr<const Executor> exec,
119+
const parameters_type& params = {})
115120
: EnablePolymorphicObject<IdentityFactory, LinOpFactory>(exec)
116121
{}
117122
};

0 commit comments

Comments
 (0)