##// END OF EJS Templates
Fixed erros in sfb, idualtree, irls_dn, and irls_dn2
yamaro -
r20:21
parent child
Show More
@@ -310,30 +310,32
310 310 chi2 = np.sum((es/sigma)**2);
311 311
312 312
313 # CS inversion using Iteratively Reweighted Least Squares (IRLS)-------------
313 # --------- CS inversion using Iteratively Reweighted Least Squares (IRLS) -------------
314 314
315 315 # (Use Nr, thetar, gnz, and Hr from MaxEnt above)
316 316
317 Psi = deb4_basis(Nr); ###### REPLACED BY LINEs BELOW (?)
318
319 print 'FINALLY!'
320 print Psi.shape
321
322 # REMOVE THIS?--------------------------------
317 Psi = deb4_basis(Nr);
318
319 # REMOVE THIS--------------------------------
323 320 #wavelet1 = pywt.Wavelet('db4')
324 321 #Phi, Psi, x = wavelet1.wavefun(level=3)
325 322 # --------------------------------------------
326 323
327 324 # add "sum to 1" constraint
328 # H2 = np.concatenate( (Hr, np.ones(shape=(1,Nr))), axis=0 );
329 # N_temp = np.array([[Nr/Nt]]);
330 # g2 = np.concatenate( (gnz, N_temp), axis=0 );
331 # H2 = H2.T.conj();
332 #
333 # print 'H2 shape', H2.shape
334 # print 'Psi shape', Psi.shape
335 #
336 # s = irls_dn2(np.dot(H2,Psi),g2,0.5,G);
325 H2 = np.concatenate( (Hr, np.ones(shape=(1,Nr))), axis=0 );
326 N_temp = np.array([[Nr/Nt]]);
327 g2 = np.concatenate( (gnz, N_temp), axis=0 );
328
329 #H2 = H2.T.conj();
330
331 #Psi = Psi.T.conj(); # to align matrices
332
333 ####print 'H2 shape', H2.shape
334 #####print 'Psi shape', Psi.shape
335
336 A = np.dot(H2,Psi);
337
338 s = irls_dn2(np.dot(H2,Psi),g2,0.5,G);
337 339 # f_cs = Psi*s;
338 340 #
339 341 # # plot
@@ -345,8 +347,6
345 347
346 348 # # # Scaling and shifting
347 349 # # # Only necessary for capon solution
348
349
350 350 f_capon = f_capon/np.max(f_capon)*np.max(fact);
351 351
352 352
@@ -25,8 +25,6
25 25 N = x.size;
26 26 n = np.arange(N);
27 27 n = np.mod(n-m, N);
28
29 print x.shape
30 28
31 29 y = x[0,n];
32 30
@@ -13,8 +13,7
13 13 def deb4_basis(N):
14 14
15 15 Psi = np.zeros(shape=(N,2*N+1));
16 idx = 1;
17
16 idx = 0;
18 17 J = 4;
19 18 [Faf, Fsf] = FSfarras();
20 19 [af, sf] = dualfilt1();
@@ -22,7 +21,6
22 21 # compute transform of zero vector
23 22 x = np.zeros(shape=(1,N));
24 23 w = dualtree(x, J, Faf, af);
25
26 24
27 25 # Uses both real and imaginary wavelets
28 26 for i in range (0, J):
@@ -36,6 +34,6
36 34 idx = idx + 1;
37 35
38 36 # Add uniform vector (seems to be useful if there's a background
39 Psi[:,2*N+1] = 1/np.sqrt(N);
37 Psi[:,2*N] = 1/np.sqrt(N);
40 38
41 39 return Psi No newline at end of file
@@ -36,7 +36,7
36 36 y2 = w[J][1];
37 37
38 38 for j in range (J-1, 0, -1):
39 y2 = sfb(y2, w[j][2], sf[0,1]);
39 y2 = sfb(y2, w[j][1], sf[0,1]);
40 40
41 41 y2 = sfb(y2, w[0][1], Fsf[0,1]);
42 42
@@ -18,16 +18,16
18 18 # and http://web.eecs.umich.edu/~aey/sparse/sparse11.pdf)
19 19
20 20 # Note to self: I found that "warm-starting" didn't really help too much.
21
21
22 22 [M,N] = A.shape;
23 23 # Initialize and precompute:
24 24 eps = 1e-2; # damping parameter
25 [Q,R] = linalg.qr(A.T.conj(),0);
26 print A.shape
27 print R.shape
28 print b.shape
25
26 [Q,R] = linalg.qr(A.T.conj(), mode='economic');
27
28
29 29 c = linalg.solve(R.T.conj(),b); # will be used later also
30 u = Q*c; # minimum 2-norm solution
30 u = np.dot(Q,c); # minimum 2-norm solution
31 31 I = sps.eye(M);
32 32
33 33 #---------- not needed, defined above--------------
@@ -51,30 +51,33
51 51 # Empty temporary N x N matrix
52 52 temp = np.zeros(shape=(N,N))
53 53
54 k = 0
54 55 # Sparse matrix
55 for i in range (1, N):
56 for j in range (1,N):
56 for i in range (0, N):
57 for j in range (0,N):
57 58 if(i==j):
58 temp[i,j] = w
59 temp[i,j] = w[k]
60 k = k+1
59 61
60 62 # Compressed Sparse Matrix
61 63 W = sps.csr_matrix(temp); #Compressed Sparse Row matrix
62 64
63 65
64 66 WAT = W*A.T.conj();
65 u_new = WAT * ( linalg.solve (A*WAT + lambda1*I), b);
67
68 u_new = np.dot(WAT , linalg.solve(np.dot(A,WAT) + np.dot(lambda1,I), b));
66 69
67 70 # See if this subproblem is converging
68 71 delu = np.linalg.norm(u_new-u)/np.linalg.norm(u);
69 epschange = delu < (np.sqrt(eps)/100);
72 epschange = delu < (np.sqrt(eps)/100.0);
70 73
71 74 # Make update
72 75 u = u_new;
73 76
74 77
75 eps = eps/10; # decrease eps
78 eps = eps/10.0; # decrease eps
76 79 # Print info
77 print 'eps =',eps;
80 #print 'eps =',eps;
78 81
79 82 return u
80 83
@@ -6,6 +6,8
6 6
7 7 from irls_dn import *
8 8 from scipy.optimize import fsolve
9 import numpy as np
10 from scipy.optimize import root
9 11
10 12 def irls_dn2(A,b,p,G):
11 13
@@ -22,23 +24,29
22 24
23 25 # Line Search
24 26
25 alpha = 2; # Line search parameter
27 alpha = 2.0; # Line search parameter
26 28 lambda1 = 1e5; # What's a reasonable but safe initial guess?
27 29 u = irls_dn(A,b,p,lambda1);
28 fid = np.norm(A*u-b)**2;
30 fid = np.linalg.norm(np.dot(A,u)-b)**2;
29 31
30 32 print '----------------------------------\n';
31 33
32 34 while (fid >= G):
33 35 lambda1 = lambda1 / alpha; # Balance between speed and accuracy
34 36 u = irls_dn(A,b,p,lambda1);
35 fid = np.norm(A*u-b)**2;
36 print 'lambda1 = #2e \t ||A*u-b||^2 = #.1f\n',lambda1,fid;
37 fid = np.linalg.norm(np.dot(A,u)-b)**2;
38 print 'lambda = %2e \t' % lambda1, '||A*u-b||^2 = %.1f\n' % fid;
37 39
38 40 # Refinement using fzero
39 41 lambda0 = np.array([lambda1,lambda1*alpha]); # interval with zero-crossing
40
41 f = lambda lambda1: np.norm(A*irls_dn(A,b,p,lambda1) - b)**2 - G;
42
43 def myfun(lambda1):
44 print "A = ", A.shape
45 print "b = ", b.shape
46 lambda1
47 return np.linalg.norm(A*irls_dn(A,b,p,lambda1) - b)**2 - G;
48
49 #f = lambda lambda1: np.linalg.norm(A*irls_dn(A,b,p,lambda1) - b)**2 - G; NOOOOOO
42 50
43 51
44 52 # opts = optimset('fzero');
@@ -46,7 +54,20
46 54 # opts.Display = 'none';
47 55 # opts.TolX = 0.01*lambda1;
48 56
49 lambda1 = fsolve(f,lambda0); # FALTA OPTIMIZE ESTO
57 #sol1 = fsolve(myfun,lambda0.ravel(), args=(), xtol=1e-14, maxfev=100000);
58 print "tolerancia=", 0.01*lambda1
59
60 #lambda1 = root(myfun,lambda0, method='krylov', tol=0.01*lambda1);
61
62
63 print "lamda1=", lambda1
64 print "lambda0=", lambda0
65
66 lambda1 = fsolve(myfun,lambda0); # FALTA OPTIMIZE ESTO
67
68 print "A = ", A.shape
69 print "b = ", b.shape
70 print "lambda1=", lambda1.shape
50 71
51 72 u = irls_dn(A,b,p,lambda1);
52 73
@@ -29,27 +29,19
29 29
30 30 N = 2*lo.size;
31 31 L = sf.size/2;
32 #print 'N', N
33 #print 'sf', sf
34
35
36 #print 'sf[:,0]', sf[:,0].shape
37 #print 'sf[:,1]', sf[:,1].shape
38 #print 'sbf hi', hi.shape
39
40
41 32
42 33 # Need to change format for upfirdn funct:
43 34 lo = lo.T.conj()
44 35 lo = lo.reshape(lo.size)
45 36
46 print 'sfb hi', hi
37 #print 'sfb hi', hi
47 38
48 39 # Need to change format for upfirdn funct:
49 40 hi = hi.T.conj()
50 41 hi = hi.reshape(hi.size)
51 42
52 43 #hi = hi.reshape(1, hi.size)
44
53 45
54 46 lo = upfirdn(lo, sf[:,0], 2, 1);
55 47 hi = upfirdn(hi, sf[:,1], 2, 1);
@@ -57,10 +49,10
57 49 y[0:L-1] = y[0:L-1] + y[N+ np.arange(0,L-1)]; #CHECK IF ARANGE IS CORRECT
58 50 y = y[0:N];
59 51
60 print 'y en sbf\n', y.shape
52 #print 'y en sbf\n', y.shape
61 53
62 54 y = y.reshape(1, y.size)
63 print 'y en sbf\n', y.shape
55 #print 'y en sbf\n', y.shape
64 56
65 57 y = cshift(y, 1-L/2);
66 58
General Comments 0
You need to be logged in to leave comments. Login now