引言:
无意间看到国外一个网站写的Matrix类,实现了加减乘除基本运算以及各自的const版本等等,功能还算比较完善,于是记录下来,以备后用:
#ifndef MATRIX_H
#define MATRIX_H

#include <algorithm>   // std::transform, std::equal, std::mismatch
#include <cstddef>     // std::size_t
#include <functional>  // std::plus, std::minus
#include <iostream>    // std::cout, std::ostream (printMatrix)
#include <utility>     // std::pair

// This header intentionally has no `using namespace std;`: a using-directive
// in a header leaks into every translation unit that includes it.

/************************************************************************/
/*  Fixed-size Rows x Cols matrix of ElemType.                          */
/*  Elements are stored contiguously in row-major order: element        */
/*  (r, c) lives at index r * Cols + c.                                 */
/************************************************************************/
template <std::size_t Rows, std::size_t Cols, typename ElemType>
class Matrix
{
public:
    // Element type. Free scalar operators name it so that the scalar
    // argument does not take part in template argument deduction.
    typedef ElemType value_type;

    // Default constructor: every element is value-initialized
    // (zero for arithmetic types).
    Matrix();

    // Fills the matrix in row-major order from [rangeBegin, rangeEnd).
    // At most size() elements are read. If the range is shorter, the
    // remaining elements stay value-initialized.
    template <typename InputIterator>
    Matrix(InputIterator rangeBegin, InputIterator rangeEnd);

    // Element at (row, col), by reference. No bounds checking is done.
    ElemType& at(std::size_t row, std::size_t col);
    const ElemType& at(std::size_t row, std::size_t col) const;

    std::size_t numRows() const;  // number of rows (Rows)
    std::size_t numCols() const;  // number of columns (Cols)
    std::size_t size() const;     // total number of elements (Rows * Cols)

    // m[row][col] syntax: operator[] returns a proxy bound to one row,
    // whose own operator[] selects the column.
    class MutableReference;
    class ImmutableReference;
    MutableReference operator[](std::size_t row);
    ImmutableReference operator[](std::size_t row) const;

    // Iterators are plain pointers into the row-major element array.
    typedef ElemType* iterator;
    typedef const ElemType* const_iterator;

    iterator begin();
    iterator end();
    const_iterator begin() const;
    const_iterator end() const;

    // Iterators spanning a single row.
    iterator row_begin(std::size_t row);
    iterator row_end(std::size_t row);
    const_iterator row_begin(std::size_t row) const;
    const_iterator row_end(std::size_t row) const;

    // Element-wise compound assignment.
    Matrix& operator+=(const Matrix& rhs);
    Matrix& operator-=(const Matrix& rhs);
    Matrix& operator*=(const ElemType& scalar);
    Matrix& operator/=(const ElemType& scalar);

    // Writes the matrix to os (std::cout by default), one row per line,
    // each element followed by a single space.
    void printMatrix(std::ostream& os = std::cout) const;

private:
    ElemType elems[Rows * Cols];  // row-major element storage
};

// Element-wise sum / difference of two same-sized matrices.
template <std::size_t M, std::size_t N, typename T>
const Matrix<M, N, T> operator+(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs);
template <std::size_t M, std::size_t N, typename T>
const Matrix<M, N, T> operator-(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs);

// Scalar multiplication (matrix * scalar, scalar * matrix) and scalar division.
// The scalar parameter is a non-deduced context, so T comes from the matrix
// alone and e.g. Matrix<2, 2, double>() * 2 compiles (2 converts to double).
template <std::size_t M, std::size_t N, typename T>
const Matrix<M, N, T> operator*(const Matrix<M, N, T>& lhs,
                                const typename Matrix<M, N, T>::value_type& scalar);
template <std::size_t M, std::size_t N, typename T>
const Matrix<M, N, T> operator*(const typename Matrix<M, N, T>::value_type& scalar,
                                const Matrix<M, N, T>& rhs);
template <std::size_t M, std::size_t N, typename T>
const Matrix<M, N, T> operator/(const Matrix<M, N, T>& lhs,
                                const typename Matrix<M, N, T>::value_type& scalar);

// Unary plus (returns a copy) and unary minus (negates every element).
template <std::size_t M, std::size_t N, typename T>
const Matrix<M, N, T> operator+(const Matrix<M, N, T>& operand);
template <std::size_t M, std::size_t N, typename T>
const Matrix<M, N, T> operator-(const Matrix<M, N, T>& operand);

// Matrix product: (M x N) * (N x P) -> (M x P).
template <std::size_t M, std::size_t N, std::size_t P, typename T>
const Matrix<M, P, T> operator*(const Matrix<M, N, T>& lhs, const Matrix<N, P, T>& rhs);

// In-place product with a square matrix: (M x N) *= (N x N).
template <std::size_t M, std::size_t N, typename T>
Matrix<M, N, T>& operator*=(Matrix<M, N, T>& lhs, const Matrix<N, N, T>& rhs);

// Comparisons: == / != are element-wise; < <= > >= are lexicographic over
// the row-major element sequence.
template <std::size_t M, std::size_t N, typename T>
bool operator==(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs);
template <std::size_t M, std::size_t N, typename T>
bool operator!=(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs);
template <std::size_t M, std::size_t N, typename T>
bool operator<(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs);
template <std::size_t M, std::size_t N, typename T>
bool operator<=(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs);
template <std::size_t M, std::size_t N, typename T>
bool operator>=(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs);
template <std::size_t M, std::size_t N, typename T>
bool operator>(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs);

// Returns the N x N identity matrix (T(1) on the diagonal, T(0) elsewhere).
template <std::size_t N, typename T>
Matrix<N, N, T> Identity();

// Returns the transpose of m.
template <std::size_t M, std::size_t N, typename T>
const Matrix<N, M, T> Transpose(const Matrix<M, N, T>& m);


/************************************************************************/
/*  Implementation                                                      */
/************************************************************************/

// Default constructor: value-initializes the element array.
template <std::size_t M, std::size_t N, typename T>
Matrix<M, N, T>::Matrix()
    : elems()
{
}

// Range constructor: bounded row-major copy.
template <std::size_t M, std::size_t N, typename T>
template <typename InputIterator>
Matrix<M, N, T>::Matrix(InputIterator rangeBegin, InputIterator rangeEnd)
    : elems()
{
    // Stop at whichever ends first, so a range longer than M * N elements
    // can never write past the end of elems.
    for (iterator out = begin(); out != end() && rangeBegin != rangeEnd; ++out, ++rangeBegin)
        *out = *rangeBegin;
}

// Element (row, col), const version. No bounds checking.
template <std::size_t M, std::size_t N, typename T>
const T& Matrix<M, N, T>::at(std::size_t row, std::size_t col) const
{
    return *(begin() + row * numCols() + col);
}

// Element (row, col), non-const version: reuses the const overload.
template <std::size_t M, std::size_t N, typename T>
T& Matrix<M, N, T>::at(std::size_t row, std::size_t col)
{
    return const_cast<T&>(static_cast<const Matrix<M, N, T>*>(this)->at(row, col));
}

template <std::size_t M, std::size_t N, typename T>
std::size_t Matrix<M, N, T>::numRows() const
{
    return M;
}

template <std::size_t M, std::size_t N, typename T>
std::size_t Matrix<M, N, T>::numCols() const
{
    return N;
}

template <std::size_t M, std::size_t N, typename T>
std::size_t Matrix<M, N, T>::size() const
{
    return M * N;
}

// Iterator to the first element.
template <std::size_t M, std::size_t N, typename T>
typename Matrix<M, N, T>::iterator Matrix<M, N, T>::begin()
{
    return elems;
}

template <std::size_t M, std::size_t N, typename T>
typename Matrix<M, N, T>::const_iterator Matrix<M, N, T>::begin() const
{
    return elems;
}

// Iterator one past the last element.
template <std::size_t M, std::size_t N, typename T>
typename Matrix<M, N, T>::iterator Matrix<M, N, T>::end()
{
    return begin() + size();
}

template <std::size_t M, std::size_t N, typename T>
typename Matrix<M, N, T>::const_iterator Matrix<M, N, T>::end() const
{
    return begin() + size();
}

// Iterator to the first element of the given row.
template <std::size_t M, std::size_t N, typename T>
typename Matrix<M, N, T>::iterator Matrix<M, N, T>::row_begin(std::size_t row)
{
    return begin() + row * numCols();
}

template <std::size_t M, std::size_t N, typename T>
typename Matrix<M, N, T>::const_iterator Matrix<M, N, T>::row_begin(std::size_t row) const
{
    return begin() + row * numCols();
}

// Iterator one past the last element of the given row.
template <std::size_t M, std::size_t N, typename T>
typename Matrix<M, N, T>::iterator Matrix<M, N, T>::row_end(std::size_t row)
{
    return row_begin(row) + N;
}

template <std::size_t M, std::size_t N, typename T>
typename Matrix<M, N, T>::const_iterator Matrix<M, N, T>::row_end(std::size_t row) const
{
    return row_begin(row) + N;
}

/************************************************************************/
/*  Row proxy returned by the non-const operator[].                     */
/************************************************************************/
template <std::size_t M, std::size_t N, typename T>
class Matrix<M, N, T>::MutableReference
{
public:
    // Mutable reference to element (row, col) of the parent matrix.
    T& operator[](std::size_t col)
    {
        return parent->at(row, col);
    }

private:
    // Private: only Matrix (a friend) can create row proxies.
    MutableReference(Matrix* owner, std::size_t rowIndex)
        : parent(owner), row(rowIndex)
    {
    }
    friend class Matrix;

    // Declared in the same order as the mem-initializer list above.
    Matrix* const parent;
    const std::size_t row;
};

/************************************************************************/
/*  Row proxy returned by the const operator[].                         */
/************************************************************************/
template <std::size_t M, std::size_t N, typename T>
class Matrix<M, N, T>::ImmutableReference
{
public:
    // Read-only reference to element (row, col) of the parent matrix.
    const T& operator[](std::size_t col) const
    {
        return parent->at(row, col);
    }

private:
    // Private: only Matrix (a friend) can create row proxies.
    ImmutableReference(const Matrix* owner, std::size_t rowIndex)
        : parent(owner), row(rowIndex)
    {
    }
    friend class Matrix;

    const Matrix* const parent;
    const std::size_t row;
};

// operator[]: wraps the requested row in the matching proxy.
template <std::size_t M, std::size_t N, typename T>
typename Matrix<M, N, T>::MutableReference Matrix<M, N, T>::operator[](std::size_t row)
{
    return MutableReference(this, row);
}

template <std::size_t M, std::size_t N, typename T>
typename Matrix<M, N, T>::ImmutableReference Matrix<M, N, T>::operator[](std::size_t row) const
{
    return ImmutableReference(this, row);
}

/************************************************************************/
/*  Compound assignment                                                 */
/************************************************************************/
template <std::size_t M, std::size_t N, typename T>
Matrix<M, N, T>& Matrix<M, N, T>::operator+=(const Matrix<M, N, T>& rhs)
{
    std::transform(begin(), end(), rhs.begin(), begin(), std::plus<T>());
    return *this;
}

template <std::size_t M, std::size_t N, typename T>
Matrix<M, N, T>& Matrix<M, N, T>::operator-=(const Matrix<M, N, T>& rhs)
{
    std::transform(begin(), end(), rhs.begin(), begin(), std::minus<T>());
    return *this;
}

template <std::size_t M, std::size_t N, typename T>
Matrix<M, N, T>& Matrix<M, N, T>::operator*=(const T& scalar)
{
    // Copy first: scalar may alias one of our own elements (m *= m.at(0, 0)),
    // which the loop below would otherwise overwrite mid-way.
    const T factor = scalar;
    for (iterator it = begin(); it != end(); ++it)
        *it = *it * factor;
    return *this;
}

template <std::size_t M, std::size_t N, typename T>
Matrix<M, N, T>& Matrix<M, N, T>::operator/=(const T& scalar)
{
    // Copy first for the same aliasing reason as operator*=.
    const T divisor = scalar;
    for (iterator it = begin(); it != end(); ++it)
        *it = *it / divisor;
    return *this;
}

/************************************************************************/
/*  Binary / unary arithmetic (built on the compound operators)         */
/************************************************************************/
template <std::size_t M, std::size_t N, typename T>
const Matrix<M, N, T> operator+(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs)
{
    return Matrix<M, N, T>(lhs) += rhs;
}

template <std::size_t M, std::size_t N, typename T>
const Matrix<M, N, T> operator-(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs)
{
    return Matrix<M, N, T>(lhs) -= rhs;
}

template <std::size_t M, std::size_t N, typename T>
const Matrix<M, N, T> operator*(const Matrix<M, N, T>& lhs,
                                const typename Matrix<M, N, T>::value_type& scalar)
{
    return Matrix<M, N, T>(lhs) *= scalar;
}

template <std::size_t M, std::size_t N, typename T>
const Matrix<M, N, T> operator*(const typename Matrix<M, N, T>::value_type& scalar,
                                const Matrix<M, N, T>& rhs)
{
    return Matrix<M, N, T>(rhs) *= scalar;
}

template <std::size_t M, std::size_t N, typename T>
const Matrix<M, N, T> operator/(const Matrix<M, N, T>& lhs,
                                const typename Matrix<M, N, T>::value_type& scalar)
{
    return Matrix<M, N, T>(lhs) /= scalar;
}

template <std::size_t M, std::size_t N, typename T>
const Matrix<M, N, T> operator+(const Matrix<M, N, T>& operand)
{
    return operand;
}

template <std::size_t M, std::size_t N, typename T>
const Matrix<M, N, T> operator-(const Matrix<M, N, T>& operand)
{
    return Matrix<M, N, T>(operand) *= T(-1);
}

// Textbook triple loop; each entry accumulates from T(0).
template <std::size_t M, std::size_t N, std::size_t P, typename T>
const Matrix<M, P, T> operator*(const Matrix<M, N, T>& lhs, const Matrix<N, P, T>& rhs)
{
    Matrix<M, P, T> result;
    for (std::size_t row = 0; row < M; ++row)
    {
        for (std::size_t col = 0; col < P; ++col)
        {
            T sum = T(0);
            for (std::size_t i = 0; i < N; ++i)
                sum += lhs.at(row, i) * rhs.at(i, col);
            result.at(row, col) = sum;
        }
    }
    return result;
}

// The product is computed into a temporary before assignment, so rhs may alias lhs.
template <std::size_t M, std::size_t N, typename T>
Matrix<M, N, T>& operator*=(Matrix<M, N, T>& lhs, const Matrix<N, N, T>& rhs)
{
    return lhs = lhs * rhs;
}

/************************************************************************/
/*  Comparisons                                                         */
/************************************************************************/
template <std::size_t M, std::size_t N, typename T>
bool operator==(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs)
{
    return std::equal(lhs.begin(), lhs.end(), rhs.begin());
}

template <std::size_t M, std::size_t N, typename T>
bool operator!=(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs)
{
    return !(lhs == rhs);
}

// Lexicographic less-than. Both sequences have the same length, so it is
// enough to find the first differing element and compare those two.
template <std::size_t M, std::size_t N, typename T>
bool operator<(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs)
{
    std::pair<typename Matrix<M, N, T>::const_iterator,
              typename Matrix<M, N, T>::const_iterator> disagreement =
        std::mismatch(lhs.begin(), lhs.end(), rhs.begin());

    // lhs < rhs only if there is a mismatch and lhs's element is the smaller one.
    return disagreement.first != lhs.end() &&
           *disagreement.first < *disagreement.second;
}

// The remaining relational operators are expressed through operator<.
template <std::size_t M, std::size_t N, typename T>
bool operator<=(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs)
{
    return !(rhs < lhs);   // x <= y  iff  !(y < x)
}

template <std::size_t M, std::size_t N, typename T>
bool operator>=(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs)
{
    return !(lhs < rhs);   // x >= y  iff  !(x < y)
}

template <std::size_t M, std::size_t N, typename T>
bool operator>(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs)
{
    return rhs < lhs;      // x > y  iff  y < x  (strict; equal matrices are not >)
}

/************************************************************************/
/*  Transpose / Identity                                                */
/************************************************************************/
template <std::size_t M, std::size_t N, typename T>
const Matrix<N, M, T> Transpose(const Matrix<M, N, T>& m)
{
    Matrix<N, M, T> result;
    for (std::size_t row = 0; row < M; ++row)
        for (std::size_t col = 0; col < N; ++col)
            result.at(col, row) = m.at(row, col);
    return result;
}

template <std::size_t N, typename T>
Matrix<N, N, T> Identity()
{
    Matrix<N, N, T> result;
    for (std::size_t row = 0; row < N; ++row)
        for (std::size_t col = 0; col < N; ++col)
            result.at(row, col) = (row == col ? T(1) : T(0));
    return result;
}

// Prints one row per line; every element is followed by a single space.
template <std::size_t M, std::size_t N, typename T>
void Matrix<M, N, T>::printMatrix(std::ostream& os) const
{
    for (std::size_t i = 0; i < numRows(); ++i)
    {
        for (std::size_t j = 0; j < numCols(); ++j)
            os << at(i, j) << " ";
        os << '\n';
    }
}

#endif
测试代码:
原站上并没有测试代码,为了验证类的正确性,自己写了一个简单的测试代码,仅供参考:
1 #include2 #include 3 #include "matrix.h" 4 5 using namespace std; 6 7 8 void testMatrixClass(); 9 10 int main()11 {12 testMatrixClass();13 14 return 0;15 }16 void testMatrixClass()17 {18 vector vec1,vec2;19 for (int i = 0; i < 6;++i)20 {21 vec1.push_back(i);22 }23 for (int i = 0; i < 12; ++i)24 {25 vec2.push_back(i+1);26 }27 vector ::iterator itBegin = vec1.begin();28 vector ::iterator itEnd = vec1.end();29 30 31 Matrix<2, 3, int>m_matrix1(itBegin,itEnd ); //用迭代器构造矩阵对象32 Matrix<3, 4, int>m_matrix2(vec2.begin(),vec2.end());33 cout << "---------Matrix 1 = :-----------------" << endl;34 m_matrix1.printMatrix();35 cout << "---------Matrix 2 = :-----------------" << endl;36 m_matrix2.printMatrix();37 cout << "-----matrix1(1,1) (从0开始)= " << m_matrix1.at(1, 1) << endl;38 cout << "---matrix1's size = " << m_matrix1.size() << " rows = " << m_matrix1.numRows()39 << " cols = " << m_matrix1.numCols() << endl;40 cout << "----matrix1 *3 = " << endl;41 (m_matrix1 *= 3).printMatrix();42 cout << "----matrix1 * matrix 2 = " << endl;43 Matrix<2, 4, int> result;44 result = m_matrix1*m_matrix2;45 result.printMatrix();46 47 }