Skip to content

Commit

Permalink
add Schur decomposition
Browse files Browse the repository at this point in the history
  • Loading branch information
stylewarning committed Sep 9, 2022
1 parent 4f670ee commit c2a9d9e
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 0 deletions.
1 change: 1 addition & 0 deletions magicl.asd
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
:components ((:file "lapack-generics")
(:file "lapack-templates")
(:file "lapack-bindings")
(:file "lapack-schur")
(:file "lapack-qz")
(:file "lapack-csd")))))

Expand Down
2 changes: 2 additions & 0 deletions src/extensions/lapack/lapack-generics.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

(magicl:define-extensible-function (magicl:svd lapack-svd :lapack) (matrix &key reduced))

(magicl:define-extensible-function (magicl:schur schur-extension :lapack) (matrix))

(magicl:define-extensible-function (magicl:qz qz-extension :lapack) (matrix1 matrix2))

(magicl:define-extensible-function (magicl:ql ql-extension :lapack) (matrix))
Expand Down
78 changes: 78 additions & 0 deletions src/extensions/lapack/lapack-schur.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
;;;; lapack-schur.lisp
;;;;
;;;; Author: Robert Smith

(in-package #:magicl-lapack)

(defmethod schur-extension ((a magicl:matrix/double-float))
(assert (magicl:square-matrix-p a))
;; TODO: This probably doesn't properly take into account the tensor
;; layout, etc.
(let* ((aa (magicl:deep-copy-tensor a))
(n (magicl:nrows a))
(ttr (make-array n :element-type 'double-float :initial-element 0.0d0))
(tti (make-array n :element-type 'double-float :initial-element 0.0d0))
(tt (magicl:zeros (list n n) :type '(complex double-float)))
(zz (magicl:zeros (magicl:shape a) :type 'double-float))
(lwork (* 3 n))
(info 0))
(flet ((arr (i &optional (ty 'double-float))
(make-array i :element-type ty)))
(declare (inline arr))
(magicl.lapack-cffi:%dgees
"V"
"N"
0
n
(magicl::storage aa)
n
0
ttr
tti
(magicl::storage zz)
n
(arr lwork)
lwork
(arr n '(signed-byte 32)) ; not referenced
info) ; INFO
;; TODO: we need to check info
(dotimes (i n)
(setf (magicl:tref tt i i) (complex (aref ttr i) (aref tti i))))
(values zz tt))))

(defmethod schur-extension ((a magicl:matrix/complex-double-float))
(assert (magicl:square-matrix-p a))
;; TODO: This probably doesn't properly take into account the tensor
;; layout, etc.
(let* ((aa (magicl:deep-copy-tensor a))
(n (magicl:nrows a))
(tt-diag (make-array n :element-type '(complex double-float)
:initial-element #C(0.0d0 0.0d0)))
(tt (magicl:zeros (list n n) :type '(complex double-float)))
(zz (magicl:zeros (magicl:shape a) :type '(complex double-float)))
(lwork (* 2 n))
(info 0))
(flet ((arr (i &optional (ty '(complex double-float)))
(make-array i :element-type ty)))
(declare (inline arr))
(magicl.lapack-cffi:%zgees
"V"
"N"
0
n
(magicl::storage aa)
n
0
tt-diag
(magicl::storage zz)
n
(arr lwork)
lwork
(arr n 'double-float)
(arr n '(signed-byte 32)) ; not referenced
info) ; INFO
;; TODO: we need to check info
(dotimes (i n)
(setf (magicl:tref tt i i) (aref tt-diag i)))
(values zz tt))))

5 changes: 5 additions & 0 deletions src/high-level/matrix.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,11 @@ See also: MAGICL:CSD"))
(define-extensible-function (svd svd-lisp) (matrix &key reduced)
(:documentation "Find the SVD of a matrix M. Return (VALUES U SIGMA Vt) where M = U @ SIGMA @ Vt"))

(define-backend-function schur (matrix1)
"Compute the Schur decomposition of a square matrix MATRIX. Return (VALUES ZZ TT) such that
MATRIX = ZZ @ TT @ ZZ*.")

(define-backend-function qz (matrix1 matrix2)
"Compute the QZ decomposition (aka the generalized Schur decomposition) on the pair of square matrices MATRIX1 and MATRIX2. Return (VALUES AA BB Q Z) such that
Expand Down
1 change: 1 addition & 0 deletions src/packages.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@
#:csd-blocks
#:csd
#:svd
#:schur
#:qz
#:ql
#:qr
Expand Down
21 changes: 21 additions & 0 deletions tests/high-level-tests.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,27 @@
(is (magicl:= a (magicl:@ q aa (magicl:dagger z)) 1d-10))
(is (magicl:= b (magicl:@ q bb (magicl:dagger z)) 1d-10))))))

(deftest test-schur-real ()
"Test SCHUR on real matrices"
(loop :repeat 20 :do
(let ((a (magicl:random-hermitian 4)))
(multiple-value-bind (zz tt)
(magicl:schur a)
;; XXX: Multiplication is not yet defined for real*complex matrices.
(let ((zz-c (magicl:zeros (magicl:shape zz) :type `(complex ,(magicl:element-type zz))))
(a-c (magicl:zeros (magicl:shape a) :type `(complex ,(magicl:element-type a)))))
(magicl::map-to #'complex zz zz-c)
(magicl::map-to #'complex a a-c)
(is (magicl:= a-c (magicl:@ zz-c tt (magicl:transpose zz-c)))))))))

(deftest test-schur-complex ()
"Test SCHUR on complex matrices"
(loop :repeat 20 :do
(let ((a (magicl:random-unitary 4)))
(multiple-value-bind (zz tt)
(magicl:schur a)
(is (magicl:= a (magicl:@ zz tt (magicl:dagger zz))))))))

(deftest test-svd ()
"Test the full and reduced SVDs."
(labels ((mul-diag-times-gen (diag matrix)
Expand Down

0 comments on commit c2a9d9e

Please sign in to comment.