2010年6月18日 星期五

Problem 442 Matrix Chain Multiplication,矩陣相乘

此題用字串做四則運算,所以要一步一步的讀取字串內容,再算出矩陣相乘的總次數。
矩陣 A 為 m1 * n1,矩陣 B 為 m2 * n2,假設 n1 = m2,則 A * B 的相乘次數為 m1 * n1 * m2 或 m1 * n2 * m2。

一開始先建構一個結構,裡面記錄矩陣字元、矩陣寬、矩陣長、相乘次數,C 語言程式碼如下:
struct matrix
{
char ch;
int row;
int col;
int count;
};
struct matrix m[26], stack[200];
接著在主程式讀入所有矩陣的資料:
scanf("%d", &n);
getchar();
for (i = 0; i < n; i ++)
{
scanf("%c %d %d", &ch, &row, &col);
getchar();
int index = ch - 'A';
m[index].ch = ch, m[index].row = row, m[index].col = col;
}
再來,就是開始計算的時候了,剛剛建構結構時,多見了一個 stack 陣列,那是要將一步一步讀取的資料堆在這個陣列。

一開始將 index 定義為 0,index 為 stack 陣列索引。
●如果遇到字元 '(',將 '(' 加入 stack, index 加 1。
●如果遇到字元為 A-Z,判斷 stack 的最上面索引的字元也為 A-Z,就將兩個矩陣相乘,並且判斷兩矩陣相乘會不會違反規則,若不會,則將相乘次數累計在最上面索引的相乘次數;若最上面索引的字元不為 A-Z,則將此次的矩陣資料加入 stack, index 加 1。
●如果遇到字元 ')',先判斷 stack 是否為空,為空則錯誤;再判斷最上面索引的字元若為 ')',則 index 減 1。最後判斷最上面索引的字元為 A-Z 和次上面索引的字元為 '(',就將最上面索引的資料下推,index 減 1,接著判斷次上面索引的字元為 A-Z,再將最上面和次上面的矩陣相乘,相乘次數累計在次上面索引,再將 index 減 1。

以上看不懂沒關係,提供程式碼最實在,由以上觀念寫成 C 語言程式碼如下:
while (gets(str))
{
int count = 0, error = 0, index = 0;
int row = 0, col = 0;
for (i = 0; (ch = str[i]); i ++)
{
if (ch == '(') stack[index ++].ch = ch;
if (isupper(ch))
{
int s = ch - 'A';
row = m[s].row, col = m[s].col;
if (isupper(stack[index - 1].ch))
{
if (stack[index - 1].col != row)
{ error = 1; break; }
stack[index - 1].count += stack[index - 1].row * row * col;
stack[index - 1].col = col;
}
else
{
stack[index].ch = ch;
stack[index].row = row;
stack[index].col = col;
stack[index].count = 0;
index ++;
}
}
if (ch == ')')
{
if (index - 1 < 0) { error = 1; break; }
if (stack[index - 1].ch == '(') index --;
if (index - 2 >= 0 &&
isupper(stack[index - 1].ch) && stack[index - 2].ch == '(')
{
stack[index - 2].ch = stack[index - 1].ch;
stack[index - 2].row = stack[index - 1].row;
stack[index - 2].col = stack[index - 1].col;
stack[index - 2].count = stack[index - 1].count;
index --;
if (isupper(stack[index - 2].ch))
{
row = stack[index - 1].row;
col = stack[index - 1].col;
count = stack[index - 1].count;

if (stack[index - 2].col != row)
{ error = 1; break; }
stack[index - 2].count += (count + stack[index - 2].row * row * col);
stack[index - 2].col = col;
index --;
}
}
}
}
if (error == 0) printf("%d\n", stack[0].count);
else if (error) printf("error\n");
}

By David.K

p442題目連結
回ACM題庫目錄
回首頁

沒有留言: