Thursday, September 18, 2008

Matrix Chain Multiplication

经典教科书题
N个矩阵,N+1个维数

The chain consists of N matrices $A_1, A_2, \dots, A_N$ described by N+1 dimensions $M_0, M_1, \dots, M_N$. The matrices are numbered from 1 here, matching the code below; a 0-based numbering $0, \dots, N-1$ works the same way with every matrix index shifted by one.

$$\underbrace{A_1}_{M_0 \times M_1} \cdot \underbrace{A_2}_{M_1 \times M_2} \cdot \underbrace{A_3}_{M_2 \times M_3} \cdots \underbrace{A_N}_{M_{N-1} \times M_N}$$

The $i$-th matrix $A_i$ ($1 \le i \le N$) has dimension $M_{i-1} \times M_i$.
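第i到j个矩阵相乘的代价. Let $d[i][j]$ be the minimum number of scalar multiplications needed to compute $A_i A_{i+1} \cdots A_j$:

$$d[i][i] = 0$$
$$d[i][j] = \min_{i \le k < j} \bigl( d[i][k] + d[k+1][j] + M_{i-1} M_k M_j \bigr), \qquad i < j$$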
// Shared 1-based DP tables (indices 1..N are used).
// C[i][j]: minimum scalar-multiplication cost of the chain A_i..A_j; written by
// mcm, and used as the memo table (INF = "not computed yet") by lookupMcm.
int C[N+1][N+1];
// K[i][j]: split point k of an optimal parenthesization (A_i..A_k)(A_{k+1}..A_j);
// written by mcm and read by print_opt.
// NOTE(review): N is not defined in this snippet; presumably a compile-time
// maximum chain length defined elsewhere.
int K[N+1][N+1];
void mcm(vector<int> &M, int n) {
for(int i = 1; i <= n; ++i)
C[i][i] = 0;
for(int d = 2; d <= n; ++d) {
for(int i = 1; i <= n-d+1; ++i) {
int j = i + d - 1;
C[i][j] = INF;
for(int k = i; k <= j-1; ++k) {
int tmp = C[i][k] + C[k+1][j] + M[i-1] * M[k] * M[j];
if( tmp < C[i][j]) {
K[i][j] = k;
C[i][j] = tmp;
}
}
}
}
}
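mcm(M, n) 之后, C[1][n]中存着从1到n连乘最小的代价 (the minimum cost of multiplying A_1..A_n); 从K[][]可以得到最优的括号化方案 (the optimal parenthesization), 用 print_opt(1, n) 打印。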
// Prints the optimal parenthesization of A_i..A_j recorded in K[][],
// e.g. "((A1 x A2) x A3)". Requires K[i][j] to be filled for every i < j
// reached by the recursion (as mcm does).
void print_opt(int i, int j) {
    // A single matrix is printed by name.
    if (i == j) {
        cout << "A" << i;
        return;
    }

    // Otherwise wrap the two halves of the optimal split in parentheses.
    const int split = K[i][j];
    cout << "(";
    print_opt(i, split);
    cout << " x ";
    print_opt(split + 1, j);
    cout << ")";
}
调用时传入mcm()一个长度为N+1的数组,和N(不是N+1!)
// Fill M[0..n] with the n+1 dimensions before calling (matrix i is M[i-1] x M[i]).
vector<int> M(n+1);
mcm(M, n);  // pass the matrix count n, not the array length n+1
下面是memoization的版本,编程简单
int lookupMcm(vector<int> &M, int i, int j) {
if(C[i][j] != INF) return C[i][j];
if(i == j)
return C[i][j] = 0;
for(int k = i; k < j; ++k) {
int q = lookupMcm(M, i, k) + lookupMcm(M, k+1, j) + M[i-1] * M[k] * M[j];
if( C[i][j] > q)
C[i][j] = q;
}
return C[i][j];
}

调用前将C[][]初始化为INF
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= n; ++j)
C[i][j] = INF;

printf("%d\n", lookupMcm(M, 1, n));
The same interval-DP technique also solves the following problem (asked in today's interview, Oct 21):

Find the length of the longest regular bracket sequence that is a subsequence of s.

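主要思想就是从中间出发,往两边发展 (grow intervals outward, from shorter substrings to longer ones):

如果已知dp[i][j]则考虑 dp[i-1][j+1]. Here dp[i][j] is the length of the longest regular subsequence of s[i..j]; once it is known, extend the interval by one character on each side.

If s[i-1] matches s[j+1], then $dp[i-1][j+1] = dp[i][j] + 2$ (the outer pair wraps the inner answer).

In every case, not only when they do not match, also try every split point:
$$dp[i-1][j+1] = \max\Bigl(dp[i-1][j+1],\ \max_{i-1 \le k < j+1} \bigl(dp[i-1][k] + dp[k+1][j+1]\bigr)\Bigr)$$
For example, for "()()" the outer characters match and give only 2, but the split gives 4.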

#include <cstdio>
#include <cstring>

// Side length of the DP table; main also sizes its input buffer from it, so it
// bounds the length of a bracket sequence that can be processed.
const int MAX = 101;

// dynamic programming 2d array:
// dp[i][j] = length of the longest regular bracket subsequence of buff[i..j]
// (filled by main and reset with memset for every input line).
int dp[MAX][MAX];

// Returns true when a is an opening bracket and b is its closing partner,
// i.e. the pair is "()" or "[]". Any other pair (including "{}") is false.
inline bool match(char a, char b) {
    switch (a) {
    case '(':
        return b == ')';
    case '[':
        return b == ']';
    default:
        return false;
    }
}

// Reads bracket sequences from stdin, one per line, until EOF or a line
// starting with 'e' ("end"). For each line it prints the length of the longest
// regular bracket subsequence (only "()" and "[]" pairs count).
int main() {
    // fgets keeps the trailing newline: leave room for up to MAX brackets, one
    // extra character to detect over-long lines, and the terminating '\0'.
    // (gets() was unbounded and has been removed from C++14 onwards.)
    char buff[MAX + 2];

    // Terminate at EOF or at the "end" line (recognised by its leading 'e').
    while (fgets(buff, sizeof(buff), stdin) && buff[0] != 'e') {
        // Strip the line terminator ("\n" or "\r\n") that fgets retains.
        const int n = static_cast<int>(strcspn(buff, "\r\n"));
        buff[n] = '\0';

        // dp is MAX x MAX, so at most MAX characters fit in the table.
        if (n > MAX) {
            fprintf(stderr, "sequence longer than %d characters\n", MAX);
            return 1;
        }
        // Empty sequence: the answer is 0 (dp[0][n-1] would be out of bounds).
        if (n == 0) {
            printf("0\n");
            continue;
        }

        memset(dp, 0, sizeof(dp));

        // dp[i][j] = longest regular subsequence of buff[i..j].
        // Base case: two adjacent characters forming "()" or "[]".
        for (int i = 0; i + 1 < n; ++i)
            if (match(buff[i], buff[i + 1]))
                dp[i][i + 1] = 2;

        // Grow intervals by length; k is the distance j - i.
        for (int k = 2; k < n; ++k) {
            for (int i = 0; i + k < n; ++i) {
                const int j = i + k;
                // Outer characters match: wrap the best answer for the interior.
                if (match(buff[i], buff[j]))
                    dp[i][j] = dp[i + 1][j - 1] + 2;
                // In every case also try splitting into buff[i..m] + buff[m+1..j].
                for (int m = i; m < j; ++m)
                    if (dp[i][m] + dp[m + 1][j] > dp[i][j])
                        dp[i][j] = dp[i][m] + dp[m + 1][j];
            }
        }
        printf("%d\n", dp[0][n - 1]);
    }
    return 0;
}