前言
这个实验主要用来测试大家对现代 C++ 的掌握程度,实验要求如下:
简单翻译一下上述要求,就是我们需要实现定义在 src/include/primer/p0_starter.h
中的三个类 Matrix
、RowMatrix
和 RowMatrixOperations
,其中 Matrix
是 RowMatrix
的父类,RowMatrixOperations
定义了三个用于数组运算的成员函数:Add
、Multiply
和 GEMM
(就是 (boldsymbol{A}*boldsymbol{B} + boldsymbol{C}))。
代码实现
Matrix 类
抽象基类 Matrix
需要我们编写的代码很少,只要完成构造函数和析构函数即可,下面省略了一些不需要我们写的代码:
template <typename T> class Matrix { protected: /** * * Construct a new Matrix instance. * @param rows The number of rows * @param cols The number of columns * */ Matrix(int rows, int cols) : rows_(rows), cols_(cols), linear_(new T[rows * cols]) {} int rows_; int cols_; T *linear_; public: /** * Destroy a matrix instance. * TODO(P0): Add implementation */ virtual ~Matrix() { delete[] linear_; } };
linear_
指向一个由二维矩阵展平而得的一维数组,里面共有 rows * cols
个类型为 T
的元素。由于我们在堆上分配数组的空间使用的是 new T[]
,所以删除的时候也得用 delete[]
。
RowMatrix 类
这个类用于表示二维矩阵,需要实现父类 Matrix
中的所有纯虚函数,为了方便访问数据元素,RowMatrix
多定义了一个指针数组 data_
,里面的每个元素分别指向了二维矩阵每行首元素的地址:
template <typename T> class RowMatrix : public Matrix<T> { public: /** * Construct a new RowMatrix instance. * @param rows The number of rows * @param cols The number of columns */ RowMatrix(int rows, int cols) : Matrix<T>(rows, cols) { data_ = new T *[rows]; for (int i = 0; i < rows; ++i) { data_[i] = &this->linear_[i * cols]; } } /** * @return The number of rows in the matrix */ auto GetRowCount() const -> int override { return this->rows_; } /** * @return The number of columns in the matrix */ auto GetColumnCount() const -> int override { return this->cols_; } /** * Get the (i,j)th matrix element. * * Throw OUT_OF_RANGE if either index is out of range. * * @param i The row index * @param j The column index * @return The (i,j)th matrix element * @throws OUT_OF_RANGE if either index is out of range */ auto GetElement(int i, int j) const -> T override { if (i < 0 || i >= GetRowCount() || j < 0 || j >= GetColumnCount()) { throw Exception(ExceptionType::OUT_OF_RANGE, "The index out of range"); } return data_[i][j]; } /** * Set the (i,j)th matrix element. * * Throw OUT_OF_RANGE if either index is out of range. * * @param i The row index * @param j The column index * @param val The value to insert * @throws OUT_OF_RANGE if either index is out of range */ void SetElement(int i, int j, T val) override { if (i < 0 || i >= GetRowCount() || j < 0 || j >= GetColumnCount()) { throw Exception(ExceptionType::OUT_OF_RANGE, "The index out of range"); } data_[i][j] = val; } /** * Fill the elements of the matrix from `source`. * * Throw OUT_OF_RANGE in the event that `source` * does not contain the required number of elements. * * @param source The source container * @throws OUT_OF_RANGE if `source` is incorrect size */ void FillFrom(const std::vector<T> &source) override { if (static_cast<int>(source.size()) != GetRowCount() * GetColumnCount()) { throw Exception(ExceptionType::OUT_OF_RANGE, "The number of elements of `source` is different from matrix"); } for (int i = 0; i < GetRowCount(); ++i) { for (int j = 0; j < GetColumnCount(); ++j) { data_[i][j] = source[i * GetColumnCount() + j]; } } } /** * Destroy a RowMatrix instance. */ ~RowMatrix() override { delete[] data_; } private: T **data_; };
需要注意的是,在 RowMatrix
中访问基类部分的成员(非虚函数)时需要加上 this
指针,不然编译时会报错说找不到指定的成员。
RowMatrixOperations 类
实现该类的三个成员函数之前应该检查数据维度是否匹配,不匹配就返回空指针,否则开个循环遍历二维矩阵完成相关操作即可:
template <typename T> class RowMatrixOperations { public: /** * Compute (`matrixA` + `matrixB`) and return the result. * Return `nullptr` if dimensions mismatch for input matrices. * @param matrixA Input matrix * @param matrixB Input matrix * @return The result of matrix addition */ static auto Add(const RowMatrix<T> *matrixA, const RowMatrix<T> *matrixB) -> std::unique_ptr<RowMatrix<T>> { if (matrixA->GetRowCount() != matrixB->GetRowCount() || matrixA->GetColumnCount() != matrixB->GetColumnCount()) { return std::unique_ptr<RowMatrix<T>>(nullptr); } auto rows = matrixA->GetRowCount(); auto cols = matrixA->GetColumnCount(); auto matrix = std::make_unique<RowMatrix<T>>(rows, cols); for (int i = 0; i < rows; ++i) { for (int j = 0; j < cols; ++j) { matrix->SetElement(i, j, matrixA->GetElement(i, j) + matrixB->GetElement(i, j)); } } return matrix; } /** * Compute the matrix multiplication (`matrixA` * `matrixB` and return the result. * Return `nullptr` if dimensions mismatch for input matrices. * @param matrixA Input matrix * @param matrixB Input matrix * @return The result of matrix multiplication */ static auto Multiply(const RowMatrix<T> *matrixA, const RowMatrix<T> *matrixB) -> std::unique_ptr<RowMatrix<T>> { if (matrixA->GetColumnCount() != matrixB->GetRowCount()) { return std::unique_ptr<RowMatrix<T>>(nullptr); } auto rows = matrixA->GetRowCount(); auto cols = matrixB->GetColumnCount(); auto matrix = std::make_unique<RowMatrix<T>>(rows, cols); for (int i = 0; i < rows; ++i) { for (int j = 0; j < cols; ++j) { T sum = 0; for (int k = 0; k < matrixA->GetColumnCount(); ++k) { sum += matrixA->GetElement(i, k) * matrixB->GetElement(k, j); } matrix->SetElement(i, j, sum); } } return matrix; } /** * Simplified General Matrix Multiply operation. Compute (`matrixA` * `matrixB` + `matrixC`). * Return `nullptr` if dimensions mismatch for input matrices. * @param matrixA Input matrix * @param matrixB Input matrix * @param matrixC Input matrix * @return The result of general matrix multiply */ static auto GEMM(const RowMatrix<T> *matrixA, const RowMatrix<T> *matrixB, const RowMatrix<T> *matrixC) -> std::unique_ptr<RowMatrix<T>> { if (matrixA->GetColumnCount() != matrixB->GetRowCount()) { return std::unique_ptr<RowMatrix<T>>(nullptr); } if (matrixA->GetRowCount() != matrixC->GetRowCount() || matrixB->GetColumnCount() != matrixC->GetColumnCount()) { return std::unique_ptr<RowMatrix<T>>(nullptr); } return Add(Multiply(matrixA, matrixB).get(), matrixC); } };
测试
打开 test/primer/starter_test.cpp
,将各个测试用例里面的 DISABLED_
前缀移除,比如 TEST(StarterTest, DISABLED_SampleTest)
改为 TEST(StarterTest, SampleTest)
,之后运行下述命令:
mkdir build cd build cmake .. make starter_test ./test/starter_test
测试结果如下图所示:
总结
这次实验感觉比较简单,主要考察虚函数、模板和动态内存(包括智能指针)的知识,就是没搞明白为什么函数都用尾置返回类型,而且 Google 风格也让人很不习惯,缩进居然只有两格,函数居然开头大写。以上~~