KickJava: Java API By Example, From Geeks To Geeks.

Java > Open Source Codes > MatrixMultiply


1
2 import EDU.oswego.cs.dl.util.concurrent.*;
3
4 /**
5  * Divide and Conquer matrix multiply demo
6  **/

7
8 public class MatrixMultiply {
9
10   static final int DEFAULT_GRANULARITY = 16;
11
12   /** The quadrant size at which to stop recursing down
13    * and instead directly multiply the matrices.
14    * Must be a power of two. Minimum value is 2.
15    **/

16   static int granularity = DEFAULT_GRANULARITY;
17
18   public static void main(String JavaDoc[] args) {
19
20     final String JavaDoc usage = "Usage: java MatrixMultiply <threads> <matrix size (must be a power of two)> [<granularity>] \n Size and granularity must be powers of two.\n For example, try java MatrixMultiply 2 512 16";
21
22     try {
23       int procs;
24       int n;
25       try {
26         procs = Integer.parseInt(args[0]);
27         n = Integer.parseInt(args[1]);
28         if (args.length > 2) granularity = Integer.parseInt(args[2]);
29       }
30
31       catch (Exception JavaDoc e) {
32         System.out.println(usage);
33         return;
34       }
35
36       if ( ((n & (n - 1)) != 0) ||
37            ((granularity & (granularity - 1)) != 0) ||
38            granularity < 2) {
39         System.out.println(usage);
40         return;
41       }
42
43       float[][] a = new float[n][n];
44       float[][] b = new float[n][n];
45       float[][] c = new float[n][n];
46       init(a, b, n);
47
48       FJTaskRunnerGroup g = new FJTaskRunnerGroup(procs);
49       g.invoke(new Multiplier(a, 0, 0, b, 0, 0, c, 0, 0, n));
50       g.stats();
51
52       // check(c, n);
53
}
54     catch (InterruptedException JavaDoc ex) {}
55   }
56
57
58   // To simplify checking, fill with all 1's. Answer should be all n's.
59
static void init(float[][] a, float[][] b, int n) {
60     for (int i = 0; i < n; ++i) {
61       for (int j = 0; j < n; ++j) {
62         a[i][j] = 1.0F;
63         b[i][j] = 1.0F;
64       }
65     }
66   }
67
68   static void check(float[][] c, int n) {
69     for (int i = 0; i < n; i++ ) {
70       for (int j = 0; j < n; j++ ) {
71         if (c[i][j] != n) {
72           throw new Error JavaDoc("Check Failed at [" + i +"]["+j+"]: " + c[i][j]);
73         }
74       }
75     }
76   }
77
78   /**
79    * Multiply matrices AxB by dividing into quadrants, using algorithm:
80    * <pre>
81    * A x B
82    *
83    * A11 | A12 B11 | B12 A11*B11 | A11*B12 A12*B21 | A12*B22
84    * |----+----| x |----+----| = |--------+--------| + |---------+-------|
85    * A21 | A22 B21 | B21 A21*B11 | A21*B21 A22*B21 | A22*B22
86    * </pre>
87    */

88
89
90   static class Multiplier extends FJTask {
91     final float[][] A; // Matrix A
92
final int aRow; // first row of current quadrant of A
93
final int aCol; // first column of current quadrant of A
94

95     final float[][] B; // Similarly for B
96
final int bRow;
97     final int bCol;
98
99     final float[][] C; // Similarly for result matrix C
100
final int cRow;
101     final int cCol;
102
103     final int size; // number of elements in current quadrant
104

105     Multiplier(float[][] A, int aRow, int aCol,
106                float[][] B, int bRow, int bCol,
107                float[][] C, int cRow, int cCol,
108                int size) {
109       this.A = A; this.aRow = aRow; this.aCol = aCol;
110       this.B = B; this.bRow = bRow; this.bCol = bCol;
111       this.C = C; this.cRow = cRow; this.cCol = cCol;
112       this.size = size;
113     }
114
115     public void run() {
116
117       if (size <= granularity) {
118         multiplyStride2();
119       }
120
121       else {
122         int h = size / 2;
123
124         coInvoke(new FJTask[] {
125           seq(new Multiplier(A, aRow, aCol, // A11
126
B, bRow, bCol, // B11
127
C, cRow, cCol, // C11
128
h),
129               new Multiplier(A, aRow, aCol+h, // A12
130
B, bRow+h, bCol, // B21
131
C, cRow, cCol, // C11
132
h)),
133             
134           seq(new Multiplier(A, aRow, aCol, // A11
135
B, bRow, bCol+h, // B12
136
C, cRow, cCol+h, // C12
137
h),
138               new Multiplier(A, aRow, aCol+h, // A12
139
B, bRow+h, bCol+h, // B22
140
C, cRow, cCol+h, // C12
141
h)),
142           
143           seq(new Multiplier(A, aRow+h, aCol, // A21
144
B, bRow, bCol, // B11
145
C, cRow+h, cCol, // C21
146
h),
147               new Multiplier(A, aRow+h, aCol+h, // A22
148
B, bRow+h, bCol, // B21
149
C, cRow+h, cCol, // C21
150
h)),
151           
152           seq(new Multiplier(A, aRow+h, aCol, // A21
153
B, bRow, bCol+h, // B12
154
C, cRow+h, cCol+h, // C22
155
h),
156               new Multiplier(A, aRow+h, aCol+h, // A22
157
B, bRow+h, bCol+h, // B22
158
C, cRow+h, cCol+h, // C22
159
h))
160         });
161       }
162     }
163
164     /**
165      * Version of matrix multiplication that steps 2 rows and columns
166      * at a time. Adapted from Cilk demos.
167      * Note that the results are added into C, not just set into C.
168      * This works well here because Java array elements
169      * are created with all zero values.
170      **/

171
172     void multiplyStride2() {
173       for (int j = 0; j < size; j+=2) {
174         for (int i = 0; i < size; i +=2) {
175
176           float[] a0 = A[aRow+i];
177           float[] a1 = A[aRow+i+1];
178         
179           float s00 = 0.0F;
180           float s01 = 0.0F;
181           float s10 = 0.0F;
182           float s11 = 0.0F;
183
184           for (int k = 0; k < size; k+=2) {
185
186             float[] b0 = B[bRow+k];
187
188             s00 += a0[aCol+k] * b0[bCol+j];
189             s10 += a1[aCol+k] * b0[bCol+j];
190             s01 += a0[aCol+k] * b0[bCol+j+1];
191             s11 += a1[aCol+k] * b0[bCol+j+1];
192
193             float[] b1 = B[bRow+k+1];
194
195             s00 += a0[aCol+k+1] * b1[bCol+j];
196             s10 += a1[aCol+k+1] * b1[bCol+j];
197             s01 += a0[aCol+k+1] * b1[bCol+j+1];
198             s11 += a1[aCol+k+1] * b1[bCol+j+1];
199           }
200
201           C[cRow+i] [cCol+j] += s00;
202           C[cRow+i] [cCol+j+1] += s01;
203           C[cRow+i+1][cCol+j] += s10;
204           C[cRow+i+1][cCol+j+1] += s11;
205         }
206       }
207     }
208
209   }
210
211 }
212
Popular Tags