##// END OF EJS Templates
Fixed erros in sfb, idualtree, irls_dn, and irls_dn2
yamaro -
r20:21
parent child
Show More
@@ -310,30 +310,32
310 chi2 = np.sum((es/sigma)**2);
310 chi2 = np.sum((es/sigma)**2);
311
311
312
312
313 # CS inversion using Iteratively Reweighted Least Squares (IRLS)-------------
313 # --------- CS inversion using Iteratively Reweighted Least Squares (IRLS) -------------
314
314
315 # (Use Nr, thetar, gnz, and Hr from MaxEnt above)
315 # (Use Nr, thetar, gnz, and Hr from MaxEnt above)
316
316
317 Psi = deb4_basis(Nr); ###### REPLACED BY LINEs BELOW (?)
317 Psi = deb4_basis(Nr);
318
318
319 print 'FINALLY!'
319 # REMOVE THIS--------------------------------
320 print Psi.shape
321
322 # REMOVE THIS?--------------------------------
323 #wavelet1 = pywt.Wavelet('db4')
320 #wavelet1 = pywt.Wavelet('db4')
324 #Phi, Psi, x = wavelet1.wavefun(level=3)
321 #Phi, Psi, x = wavelet1.wavefun(level=3)
325 # --------------------------------------------
322 # --------------------------------------------
326
323
327 # add "sum to 1" constraint
324 # add "sum to 1" constraint
328 # H2 = np.concatenate( (Hr, np.ones(shape=(1,Nr))), axis=0 );
325 H2 = np.concatenate( (Hr, np.ones(shape=(1,Nr))), axis=0 );
329 # N_temp = np.array([[Nr/Nt]]);
326 N_temp = np.array([[Nr/Nt]]);
330 # g2 = np.concatenate( (gnz, N_temp), axis=0 );
327 g2 = np.concatenate( (gnz, N_temp), axis=0 );
331 # H2 = H2.T.conj();
328
332 #
329 #H2 = H2.T.conj();
333 # print 'H2 shape', H2.shape
330
334 # print 'Psi shape', Psi.shape
331 #Psi = Psi.T.conj(); # to align matrices
335 #
332
336 # s = irls_dn2(np.dot(H2,Psi),g2,0.5,G);
333 ####print 'H2 shape', H2.shape
334 #####print 'Psi shape', Psi.shape
335
336 A = np.dot(H2,Psi);
337
338 s = irls_dn2(np.dot(H2,Psi),g2,0.5,G);
337 # f_cs = Psi*s;
339 # f_cs = Psi*s;
338 #
340 #
339 # # plot
341 # # plot
@@ -345,8 +347,6
345
347
346 # # # Scaling and shifting
348 # # # Scaling and shifting
347 # # # Only necessary for capon solution
349 # # # Only necessary for capon solution
348
349
350 f_capon = f_capon/np.max(f_capon)*np.max(fact);
350 f_capon = f_capon/np.max(f_capon)*np.max(fact);
351
351
352
352
@@ -25,8 +25,6
25 N = x.size;
25 N = x.size;
26 n = np.arange(N);
26 n = np.arange(N);
27 n = np.mod(n-m, N);
27 n = np.mod(n-m, N);
28
29 print x.shape
30
28
31 y = x[0,n];
29 y = x[0,n];
32
30
@@ -13,8 +13,7
13 def deb4_basis(N):
13 def deb4_basis(N):
14
14
15 Psi = np.zeros(shape=(N,2*N+1));
15 Psi = np.zeros(shape=(N,2*N+1));
16 idx = 1;
16 idx = 0;
17
18 J = 4;
17 J = 4;
19 [Faf, Fsf] = FSfarras();
18 [Faf, Fsf] = FSfarras();
20 [af, sf] = dualfilt1();
19 [af, sf] = dualfilt1();
@@ -22,7 +21,6
22 # compute transform of zero vector
21 # compute transform of zero vector
23 x = np.zeros(shape=(1,N));
22 x = np.zeros(shape=(1,N));
24 w = dualtree(x, J, Faf, af);
23 w = dualtree(x, J, Faf, af);
25
26
24
27 # Uses both real and imaginary wavelets
25 # Uses both real and imaginary wavelets
28 for i in range (0, J):
26 for i in range (0, J):
@@ -36,6 +34,6
36 idx = idx + 1;
34 idx = idx + 1;
37
35
38 # Add uniform vector (seems to be useful if there's a background
36 # Add uniform vector (seems to be useful if there's a background
39 Psi[:,2*N+1] = 1/np.sqrt(N);
37 Psi[:,2*N] = 1/np.sqrt(N);
40
38
41 return Psi No newline at end of file
39 return Psi
@@ -36,7 +36,7
36 y2 = w[J][1];
36 y2 = w[J][1];
37
37
38 for j in range (J-1, 0, -1):
38 for j in range (J-1, 0, -1):
39 y2 = sfb(y2, w[j][2], sf[0,1]);
39 y2 = sfb(y2, w[j][1], sf[0,1]);
40
40
41 y2 = sfb(y2, w[0][1], Fsf[0,1]);
41 y2 = sfb(y2, w[0][1], Fsf[0,1]);
42
42
@@ -18,16 +18,16
18 # and http://web.eecs.umich.edu/~aey/sparse/sparse11.pdf)
18 # and http://web.eecs.umich.edu/~aey/sparse/sparse11.pdf)
19
19
20 # Note to self: I found that "warm-starting" didn't really help too much.
20 # Note to self: I found that "warm-starting" didn't really help too much.
21
21
22 [M,N] = A.shape;
22 [M,N] = A.shape;
23 # Initialize and precompute:
23 # Initialize and precompute:
24 eps = 1e-2; # damping parameter
24 eps = 1e-2; # damping parameter
25 [Q,R] = linalg.qr(A.T.conj(),0);
25
26 print A.shape
26 [Q,R] = linalg.qr(A.T.conj(), mode='economic');
27 print R.shape
27
28 print b.shape
28
29 c = linalg.solve(R.T.conj(),b); # will be used later also
29 c = linalg.solve(R.T.conj(),b); # will be used later also
30 u = Q*c; # minimum 2-norm solution
30 u = np.dot(Q,c); # minimum 2-norm solution
31 I = sps.eye(M);
31 I = sps.eye(M);
32
32
33 #---------- not needed, defined above--------------
33 #---------- not needed, defined above--------------
@@ -51,30 +51,33
51 # Empty temporary N x N matrix
51 # Empty temporary N x N matrix
52 temp = np.zeros(shape=(N,N))
52 temp = np.zeros(shape=(N,N))
53
53
54 k = 0
54 # Sparse matrix
55 # Sparse matrix
55 for i in range (1, N):
56 for i in range (0, N):
56 for j in range (1,N):
57 for j in range (0,N):
57 if(i==j):
58 if(i==j):
58 temp[i,j] = w
59 temp[i,j] = w[k]
60 k = k+1
59
61
60 # Compressed Sparse Matrix
62 # Compressed Sparse Matrix
61 W = sps.csr_matrix(temp); #Compressed Sparse Row matrix
63 W = sps.csr_matrix(temp); #Compressed Sparse Row matrix
62
64
63
65
64 WAT = W*A.T.conj();
66 WAT = W*A.T.conj();
65 u_new = WAT * ( linalg.solve (A*WAT + lambda1*I), b);
67
68 u_new = np.dot(WAT , linalg.solve(np.dot(A,WAT) + np.dot(lambda1,I), b));
66
69
67 # See if this subproblem is converging
70 # See if this subproblem is converging
68 delu = np.linalg.norm(u_new-u)/np.linalg.norm(u);
71 delu = np.linalg.norm(u_new-u)/np.linalg.norm(u);
69 epschange = delu < (np.sqrt(eps)/100);
72 epschange = delu < (np.sqrt(eps)/100.0);
70
73
71 # Make update
74 # Make update
72 u = u_new;
75 u = u_new;
73
76
74
77
75 eps = eps/10; # decrease eps
78 eps = eps/10.0; # decrease eps
76 # Print info
79 # Print info
77 print 'eps =',eps;
80 #print 'eps =',eps;
78
81
79 return u
82 return u
80
83
@@ -6,6 +6,8
6
6
7 from irls_dn import *
7 from irls_dn import *
8 from scipy.optimize import fsolve
8 from scipy.optimize import fsolve
9 import numpy as np
10 from scipy.optimize import root
9
11
10 def irls_dn2(A,b,p,G):
12 def irls_dn2(A,b,p,G):
11
13
@@ -22,23 +24,29
22
24
23 # Line Search
25 # Line Search
24
26
25 alpha = 2; # Line search parameter
27 alpha = 2.0; # Line search parameter
26 lambda1 = 1e5; # What's a reasonable but safe initial guess?
28 lambda1 = 1e5; # What's a reasonable but safe initial guess?
27 u = irls_dn(A,b,p,lambda1);
29 u = irls_dn(A,b,p,lambda1);
28 fid = np.norm(A*u-b)**2;
30 fid = np.linalg.norm(np.dot(A,u)-b)**2;
29
31
30 print '----------------------------------\n';
32 print '----------------------------------\n';
31
33
32 while (fid >= G):
34 while (fid >= G):
33 lambda1 = lambda1 / alpha; # Balance between speed and accuracy
35 lambda1 = lambda1 / alpha; # Balance between speed and accuracy
34 u = irls_dn(A,b,p,lambda1);
36 u = irls_dn(A,b,p,lambda1);
35 fid = np.norm(A*u-b)**2;
37 fid = np.linalg.norm(np.dot(A,u)-b)**2;
36 print 'lambda1 = #2e \t ||A*u-b||^2 = #.1f\n',lambda1,fid;
38 print 'lambda = %2e \t' % lambda1, '||A*u-b||^2 = %.1f\n' % fid;
37
39
38 # Refinement using fzero
40 # Refinement using fzero
39 lambda0 = np.array([lambda1,lambda1*alpha]); # interval with zero-crossing
41 lambda0 = np.array([lambda1,lambda1*alpha]); # interval with zero-crossing
40
42
41 f = lambda lambda1: np.norm(A*irls_dn(A,b,p,lambda1) - b)**2 - G;
43 def myfun(lambda1):
44 print "A = ", A.shape
45 print "b = ", b.shape
46 lambda1
47 return np.linalg.norm(A*irls_dn(A,b,p,lambda1) - b)**2 - G;
48
49 #f = lambda lambda1: np.linalg.norm(A*irls_dn(A,b,p,lambda1) - b)**2 - G; NOOOOOO
42
50
43
51
44 # opts = optimset('fzero');
52 # opts = optimset('fzero');
@@ -46,7 +54,20
46 # opts.Display = 'none';
54 # opts.Display = 'none';
47 # opts.TolX = 0.01*lambda1;
55 # opts.TolX = 0.01*lambda1;
48
56
49 lambda1 = fsolve(f,lambda0); # FALTA OPTIMIZE ESTO
57 #sol1 = fsolve(myfun,lambda0.ravel(), args=(), xtol=1e-14, maxfev=100000);
58 print "tolerancia=", 0.01*lambda1
59
60 #lambda1 = root(myfun,lambda0, method='krylov', tol=0.01*lambda1);
61
62
63 print "lamda1=", lambda1
64 print "lambda0=", lambda0
65
66 lambda1 = fsolve(myfun,lambda0); # FALTA OPTIMIZE ESTO
67
68 print "A = ", A.shape
69 print "b = ", b.shape
70 print "lambda1=", lambda1.shape
50
71
51 u = irls_dn(A,b,p,lambda1);
72 u = irls_dn(A,b,p,lambda1);
52
73
@@ -29,27 +29,19
29
29
30 N = 2*lo.size;
30 N = 2*lo.size;
31 L = sf.size/2;
31 L = sf.size/2;
32 #print 'N', N
33 #print 'sf', sf
34
35
36 #print 'sf[:,0]', sf[:,0].shape
37 #print 'sf[:,1]', sf[:,1].shape
38 #print 'sbf hi', hi.shape
39
40
41
32
42 # Need to change format for upfirdn funct:
33 # Need to change format for upfirdn funct:
43 lo = lo.T.conj()
34 lo = lo.T.conj()
44 lo = lo.reshape(lo.size)
35 lo = lo.reshape(lo.size)
45
36
46 print 'sfb hi', hi
37 #print 'sfb hi', hi
47
38
48 # Need to change format for upfirdn funct:
39 # Need to change format for upfirdn funct:
49 hi = hi.T.conj()
40 hi = hi.T.conj()
50 hi = hi.reshape(hi.size)
41 hi = hi.reshape(hi.size)
51
42
52 #hi = hi.reshape(1, hi.size)
43 #hi = hi.reshape(1, hi.size)
44
53
45
54 lo = upfirdn(lo, sf[:,0], 2, 1);
46 lo = upfirdn(lo, sf[:,0], 2, 1);
55 hi = upfirdn(hi, sf[:,1], 2, 1);
47 hi = upfirdn(hi, sf[:,1], 2, 1);
@@ -57,10 +49,10
57 y[0:L-1] = y[0:L-1] + y[N+ np.arange(0,L-1)]; #CHECK IF ARANGE IS CORRECT
49 y[0:L-1] = y[0:L-1] + y[N+ np.arange(0,L-1)]; #CHECK IF ARANGE IS CORRECT
58 y = y[0:N];
50 y = y[0:N];
59
51
60 print 'y en sbf\n', y.shape
52 #print 'y en sbf\n', y.shape
61
53
62 y = y.reshape(1, y.size)
54 y = y.reshape(1, y.size)
63 print 'y en sbf\n', y.shape
55 #print 'y en sbf\n', y.shape
64
56
65 y = cshift(y, 1-L/2);
57 y = cshift(y, 1-L/2);
66
58
General Comments 0
You need to be logged in to leave comments. Login now