Strassen's algorithm is not the most efficient algorithm for matrix multiplication, but it was the first algorithm that was theoretically faster than the naive algorithm. There is a very good explanation and implementation of Strassen's algorithm on Wikipedia.

However, that implementation of Strassen's algorithm cannot be used directly, because it sets the base case of the divide-and-conquer recursion to a 1x1 matrix, which incurs a huge time cost from the recursion itself. If the base case is set to a 2x2 matrix, meaning that 2x2 and 1x1 matrices are multiplied by the naive algorithm, then Strassen's algorithm becomes more efficient than the naive algorithm for matrices larger than 512x512.

When the base case is set to a 2x2 matrix, Strassen's algorithm surpasses the naive algorithm for matrices larger than 512x512.

When the base case is set to a 2x2 matrix, Strassen's algorithm surpasses the naive algorithm for matrices larger than 512x512.

When the base case is set to a 6x6 matrix, Strassen's algorithm surpasses the naive algorithm for matrices larger than 128x128.

When the base case is set to a 6x6 matrix, Strassen's algorithm surpasses the naive algorithm for matrices larger than 128x128.


/*------------------------------------------------------------------------------*/
// matrix_mult.cc -- Implementation of matrix multiplication with
// Strassen's algorithm.
//
// Compile this file with gcc command:
// g++ -Wall -o matrix_mult matrix_mult.cc 

#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <ctype.h>
#include <unistd.h>
#include <iostream>
#include <fstream>
#include <cmath>
#include <cstring>

using namespace std;

// Creates an n x n matrix of doubles, stored as an array of n row pointers
// where each row is its own heap array of n elements.
// Every element starts out as zero.
// Returns the array of row pointers; free_matrix(M, n) releases it.
inline double** allocate_matrix(int n)
{
    double** rows = new double*[n];

    int r = 0;
    while (r < n)
    {
        // the trailing () value-initializes the row, i.e. fills it with 0.0
        rows[r] = new double[n]();
        ++r;
    }

    return rows;
}

/*------------------------------------------------------------------------------*/
// Releases a matrix that was built as an array of n separately allocated rows
// (the layout produced by allocate_matrix).
//   M - array of n row pointers; a null pointer is accepted and ignored
//   n - number of rows stored in M
// M is received by value, so the caller's own pointer is NOT reset by this
// function and must not be used after the call.
inline void free_matrix(double **M, int n)
{
    // Nothing to release; also avoids dereferencing a null row table below.
    if (!M)
        return;

    for (int i = 0; i < n; i++)
    {
        delete [] M[i];
    }

    delete [] M; // frees the array of row pointers
}

/*------------------------------------------------------------------------------*/
// Element-wise addition of two tam x tam matrices: result = a + b.
//   a, b   - operands (arrays of tam row pointers)
//   result - destination matrix, written element by element
//   tam    - dimension of the matrices
inline void sum(double **a, double **b, double **result, int tam) {
    for (int row = 0; row < tam; ++row) {
        const double *lhs = a[row];
        const double *rhs = b[row];
        double *out = result[row];

        for (int col = 0; col < tam; ++col)
            out[col] = lhs[col] + rhs[col];
    }
}

/*------------------------------------------------------------------------------*/
// Element-wise subtraction of two tam x tam matrices: result = a - b.
//   a      - minuend (array of tam row pointers)
//   b      - subtrahend (array of tam row pointers)
//   result - destination matrix, written element by element
//   tam    - dimension of the matrices
inline void subtract(double **a, double **b, double **result, int tam) {
    for (int row = 0; row < tam; ++row) {
        const double *lhs = a[row];
        const double *rhs = b[row];
        double *out = result[row];

        for (int col = 0; col < tam; ++col)
            out[col] = lhs[col] - rhs[col];
    }
}

/*------------------------------------------------------------------------------*/
// Naive (triple-loop) matrix multiplication.
// Adds the product A * B onto the current contents of C (C += A * B), so C
// has to hold zeros on entry for the result to be the plain product.
//   A, B - n x n operands (arrays of n row pointers)
//   C    - n x n accumulator / result
//   n    - dimension of the matrices
void naive(double** A, double** B,double** C, int n)
{
    for (int row = 0; row < n; ++row)
    {
        for (int col = 0; col < n; ++col)
        {
            double& cell = C[row][col];
            for (int k = 0; k < n; ++k)
            {
                cell += A[row][k] * B[k][col];
            }
        }
    }
}

/*------------------------------------------------------------------------------*/
// Strassen's method
//
// Multiplies the tam x tam matrices a and b and stores the product in c.
//   a, b - operands (arrays of tam row pointers)
//   c    - result matrix; it must be zero-filled on entry (as allocate_matrix
//          returns it), because the naive fallback accumulates into c while
//          the recursive path overwrites it
//   tam  - dimension of all three matrices
//
// Matrices of size 4 x 4 or smaller, and matrices whose dimension is odd, are
// handed to naive(); everything else is split into four quadrants and
// combined from seven recursive products.
void strassen(double **a, double **b, double **c, int tam)
{
    // Key observation: recursing all the way down is too expensive, so small
    // matrices (tam <= 4) are multiplied with the naive method.
    // An odd dimension cannot be split into four equal quadrants: tam/2 would
    // truncate and the last row and column would be silently dropped, so such
    // sizes are multiplied naively as well.
    if (tam <= 4 || tam % 2 != 0)
    {
        naive(a, b, c, tam);
        return;
    }

    // other cases are treated here:
    int newTam = tam/2;
    double **a11, **a12, **a21, **a22;
    double **b11, **b12, **b21, **b22;
    double **c11, **c12, **c21, **c22;
    double **p1, **p2, **p3, **p4, **p5, **p6, **p7;

    // memory allocation (allocate_matrix zero-fills, which the recursive
    // calls below rely on for their output matrices p1..p7):
    a11 = allocate_matrix(newTam);
    a12 = allocate_matrix(newTam);
    a21 = allocate_matrix(newTam);
    a22 = allocate_matrix(newTam);

    b11 = allocate_matrix(newTam);
    b12 = allocate_matrix(newTam);
    b21 = allocate_matrix(newTam);
    b22 = allocate_matrix(newTam);

    c11 = allocate_matrix(newTam);
    c12 = allocate_matrix(newTam);
    c21 = allocate_matrix(newTam);
    c22 = allocate_matrix(newTam);

    p1 = allocate_matrix(newTam);
    p2 = allocate_matrix(newTam);
    p3 = allocate_matrix(newTam);
    p4 = allocate_matrix(newTam);
    p5 = allocate_matrix(newTam);
    p6 = allocate_matrix(newTam);
    p7 = allocate_matrix(newTam);

    // scratch matrices holding the quadrant sums/differences fed to each product
    double **aResult = allocate_matrix(newTam);
    double **bResult = allocate_matrix(newTam);

    // dividing the matrices in 4 sub-matrices:
    for (int i = 0; i < newTam; i++) {
        for (int j = 0; j < newTam; j++) {
            a11[i][j] = a[i][j];
            a12[i][j] = a[i][j + newTam];
            a21[i][j] = a[i + newTam][j];
            a22[i][j] = a[i + newTam][j + newTam];

            b11[i][j] = b[i][j];
            b12[i][j] = b[i][j + newTam];
            b21[i][j] = b[i + newTam][j];
            b22[i][j] = b[i + newTam][j + newTam];
        }
    }

    // Calculating p1 to p7:

    sum(a11, a22, aResult, newTam); // a11 + a22
    sum(b11, b22, bResult, newTam); // b11 + b22
    strassen(aResult, bResult, p1, newTam); // p1 = (a11+a22) * (b11+b22)

    sum(a21, a22, aResult, newTam); // a21 + a22
    strassen(aResult, b11, p2, newTam); // p2 = (a21+a22) * (b11)

    subtract(b12, b22, bResult, newTam); // b12 - b22
    strassen(a11, bResult, p3, newTam); // p3 = (a11) * (b12 - b22)

    subtract(b21, b11, bResult, newTam); // b21 - b11
    strassen(a22, bResult, p4, newTam); // p4 = (a22) * (b21 - b11)

    sum(a11, a12, aResult, newTam); // a11 + a12
    strassen(aResult, b22, p5, newTam); // p5 = (a11+a12) * (b22)

    subtract(a21, a11, aResult, newTam); // a21 - a11
    sum(b11, b12, bResult, newTam); // b11 + b12
    strassen(aResult, bResult, p6, newTam); // p6 = (a21-a11) * (b11+b12)

    subtract(a12, a22, aResult, newTam); // a12 - a22
    sum(b21, b22, bResult, newTam); // b21 + b22
    strassen(aResult, bResult, p7, newTam); // p7 = (a12-a22) * (b21+b22)

    // calculating c12, c21, c11 and c22:

    sum(p3, p5, c12, newTam); // c12 = p3 + p5
    sum(p2, p4, c21, newTam); // c21 = p2 + p4

    sum(p1, p4, aResult, newTam); // p1 + p4
    sum(aResult, p7, bResult, newTam); // p1 + p4 + p7
    subtract(bResult, p5, c11, newTam); // c11 = p1 + p4 - p5 + p7

    sum(p1, p3, aResult, newTam); // p1 + p3
    sum(aResult, p6, bResult, newTam); // p1 + p3 + p6
    subtract(bResult, p2, c22, newTam); // c22 = p1 + p3 - p2 + p6

    // Grouping the results obtained in a single matrix:
    for (int i = 0; i < newTam ; i++) {
        for (int j = 0 ; j < newTam ; j++) {
            c[i][j] = c11[i][j];
            c[i][j + newTam] = c12[i][j];
            c[i + newTam][j] = c21[i][j];
            c[i + newTam][j + newTam] = c22[i][j];
        }
    }

    // deallocating memory (free):
    free_matrix(a11, newTam);
    free_matrix(a12, newTam);
    free_matrix(a21, newTam);
    free_matrix(a22, newTam);

    free_matrix(b11, newTam);
    free_matrix(b12, newTam);
    free_matrix(b21, newTam);
    free_matrix(b22, newTam);

    free_matrix(c11, newTam);
    free_matrix(c12, newTam);
    free_matrix(c21, newTam);
    free_matrix(c22, newTam);

    free_matrix(p1, newTam);
    free_matrix(p2, newTam);
    free_matrix(p3, newTam);
    free_matrix(p4, newTam);
    free_matrix(p5, newTam);
    free_matrix(p6, newTam);
    free_matrix(p7, newTam);
    free_matrix(aResult, newTam);
    free_matrix(bResult, newTam);

} // end of Strassen function

/*------------------------------------------------------------------------------*/
// Fills the n x n matrix M with pseudo-random integer values in [0, 99],
// drawing one rand() value per element in row-major order.
//   M - destination matrix (array of n row pointers)
//   n - dimension of the matrix
void gen_matrix(double** M,int n)
{
    for (int row = 0; row < n; ++row)
    {
        double* line = M[row];
        for (int col = 0; col < n; ++col)
        {
            line[col] = rand() % 100;
            // line[col] = 1;   // constant fill, handy when debugging
        }
    }
}

/*------------------------------------------------------------------------------*/
// Writes the n x n matrix M to the given fstream.
// Each element is followed by a single space, each row is terminated with
// endl, and one extra endl is emitted after the last row.
//   fs - stream to write to
//   M  - matrix to print (array of n row pointers)
//   n  - dimension of the matrix
void print_matrix(fstream& fs, double** M, int n)
{
    int row = 0;
    while (row < n)
    {
        const double* line = M[row];
        for (int col = 0; col < n; ++col)
        {
            fs << line[col] << " ";
        }
        fs << endl;
        ++row;
    }
    fs << endl;
}

/*------------------------------------------------------------------------------*/
// Records the generated matrices and the final product in a text file.
//   A, B - the n x n input matrices
//   C    - the n x n product
//   n    - dimension of the matrices
//   file - path of the file to create/overwrite; a null path is ignored
// If the file cannot be opened, a message is printed on stderr and nothing
// is written.
void mat_mult_log(double** A, double** B,double** C,int n,char* file)
{
    if (file == NULL)
        return;

    fstream fs;
    fs.open(file,fstream::out);

    // Without this check a bad path would silently produce no log at all.
    if (!fs.is_open())
    {
        fprintf (stderr, "Cannot open output file `%s'.\n", file);
        return;
    }

    fs<<"Random Matrix A:"<<endl;
    print_matrix(fs,A,n);
    fs<<"Random Matrix B:"<<endl;
    print_matrix(fs,B,n);
    fs<<"C=A * B"<<endl;
    print_matrix(fs,C,n);

    fs.close();
}

/*------------------------------------------------------------------------------*/

// Program entry point.
// Command-line options:
//   -s        multiply with Strassen's method (default: naive method)
//   -n <exp>  use matrices of dimension 2^exp (default dimension: 2)
//   -o <file> write the two random inputs and their product to <file>
// Returns 0 on success and 1 on a command-line error.
int main(int argc, char** argv)
{
    srand(static_cast<unsigned int>(time(NULL)));

    int mdim=2; // matrix dimension
    char* output=NULL;
    bool is_strassen=false;
    int c;

    while ((c = getopt (argc, argv, "sn:o:")) != -1)
    {
        switch (c)
        {
        case 's':
            is_strassen=true;
            break;
        case 'n':
        {
            // 2^n dimensions. Computed with an integer shift instead of the
            // floating-point pow(); the range check keeps the shift within
            // an int (assuming int is at least 32 bits wide) and rejects
            // negative exponents, which previously produced a 0 dimension.
            const int exponent = atoi(optarg);
            if (exponent < 0 || exponent > 30)
            {
                fprintf (stderr, "Option -n expects an exponent between 0 and 30.\n");
                return 1;
            }
            mdim = 1 << exponent;
            break;
        }
        case 'o':
            output = optarg; // path of the log file
            break;
        case '?':
            // both -n and -o take a mandatory argument
            if (optopt == 'n' || optopt == 'o')
                fprintf (stderr, "Option -%c requires an argument.\n", optopt);
            else if (isprint (optopt))
                fprintf (stderr, "Unknown option `-%c'.\n", optopt);
            else
                // "\\x" prints a literal backslash-x in front of the hex code
                fprintf (stderr,
                         "Unknown option character `\\x%x'.\n",
                         optopt);
            return 1;
        default:
            abort ();
        }
    }

    // create new matrices (C starts zero-filled, as both multipliers expect)
    double** A=allocate_matrix(mdim);
    double** B=allocate_matrix(mdim);
    double** C=allocate_matrix(mdim);
    gen_matrix(A,mdim);
    gen_matrix(B,mdim);

    // matrices multiplication
    if(is_strassen)
        strassen(A,B,C,mdim);
    else
        naive(A,B,C,mdim);

    if(output!=NULL)
        mat_mult_log(A,B,C,mdim,output);

    free_matrix(A,mdim);
    free_matrix(B,mdim);
    free_matrix(C,mdim);

    return 0;
}