{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}

-- |
-- A WAI middleware that responds to @OPTIONS@ requests for a Servant
-- application. This is especially helpful when handling CORS pre-flight
-- requests.
--
module Network.Wai.Middleware.Servant.Options (provideOptions) where

import Servant
import Servant.Foreign
import Network.Wai
import Data.Text hiding (null, zipWith, length)
import Network.HTTP.Types.Method
import Data.Maybe
import Data.List (nub)
import Network.HTTP.Types
import qualified Data.ByteString as B

-- | Wrap an 'Application' so that @OPTIONS@ requests whose path matches an
-- endpoint of the given Servant API are answered directly, with an @Allow@
-- header listing the methods available on that path.
--
-- Every other request, including an @OPTIONS@ request that matches no
-- endpoint, is passed unchanged to the wrapped application.
provideOptions :: (GenerateList NoContent (Foreign NoContent api), HasForeign NoTypes NoContent api)
               => Proxy api -> Middleware
provideOptions apiproxy app = \req cb ->
  let prior = app req cb
  in if requestMethod req == methodOptions
       then optional cb prior (pathInfo req) endpoints
       else prior
  where
    -- Flattened list of the API's endpoints. It depends only on the API
    -- proxy, so it is bound outside the per-request lambda: it is built once
    -- per wrapped application instead of on every OPTIONS request.
    endpoints :: [Req NoContent]
    endpoints = listFromAPI (Proxy :: Proxy NoTypes) (Proxy :: Proxy NoContent) apiproxy

-- | Decide how to answer an @OPTIONS@ request for the path segments @ts@.
--
-- The method of every endpoint in @rs@ whose route matches @ts@ is
-- collected. If nothing matches, the fallback @prior@ is returned as-is.
-- Otherwise the responder @cb@ is applied to a response advertising the
-- collected methods.
optional :: (Response -> r) -> r -> [Text] -> [Req NoContent] -> r
optional cb prior ts rs =
  case mapMaybe (getMethod ts) rs of
    []      -> prior
    matched -> cb (buildResponse matched)

-- | Return the HTTP method of endpoint @ps@ when its route matches the
-- request path segments @rs@, and 'Nothing' otherwise.
--
-- A route matches when it has exactly as many segments as the request and
-- every pair of corresponding segments is accepted by 'matchSegment'.
getMethod :: [Text] -> Req NoContent -> Maybe Method
getMethod rs ps
  | length rs /= length endpointSegs          = Nothing
  | and (zipWith matchSegment rs endpointSegs) = Just (_reqMethod ps)
  | otherwise                                  = Nothing
  where
    -- Segments of the endpoint's URL path, in order.
    endpointSegs = _path (_reqUrl ps)

-- | Does the request path segment @piece@ fit the endpoint segment?
--
-- A static segment must equal @piece@ exactly. Any other kind of segment
-- accepts every value.
matchSegment :: Text -> Segment NoContent -> Bool
matchSegment piece seg =
  case seg of
    Segment (Static (PathSegment b)) -> piece == b
    _                                -> True

-- | Build the reply to a matched @OPTIONS@ request.
--
-- The reply is @200 OK@ with an empty body and an @Allow@ header. The
-- header lists @OPTIONS@ first, then the other distinct methods in @ms@
-- in first-seen order, separated by @", "@.
buildResponse :: [Method] -> Response
buildResponse ms = responseBuilder status200 hdrs mempty
  where
    -- 'nub' runs after OPTIONS is prepended, so an API that declares its own
    -- OPTIONS endpoint does not get the method listed twice in the header.
    allowed :: Method
    allowed = B.intercalate ", " (nub (methodOptions : ms))

    hdrs :: ResponseHeaders
    hdrs = [ ("Allow", allowed) ]