Giter Site home page Giter Site logo

Comments (1)

atdiar avatar atdiar commented on July 17, 2024
package main

import (
	"play.ground/cors"
)

// main is a placeholder entry point for this playground module; it does
// nothing yet.
//
// NOTE(review): the "play.ground/cors" import above is never referenced, which
// Go rejects at compile time ("imported and not used"). Reference the package
// here, or remove or blank the import.
func main() {
	//foo.Bar()
}
-- go.mod --
module play.ground
-- cors/cors.go --
// Package cors implements the server-side logic that decides, for a given API
// route, whether a response may be shared with a requester from another origin.
package cors

import (
	"net/http"
	"net/textproto"
	"strconv"
	"strings"
	"time"
)

// Access control reference: https://fetch.spec.whatwg.org/#http-cors-protocol
// (the WHATWG Fetch standard supersedes https://www.w3.org/TR/cors/).

/*
Rationale
=========

A cross-origin HTTP request is made when a user agent retrieves a resource
from one origin that itself depends on resources from another origin (for
instance, an image stored on a third-party CDN).

The current package can be used to specify the conditions under which we allow a
resource at a given endpoint to be accessed.

The default is the same-origin policy: access is allowed only when scheme
(protocol), host and port all match. The server can relax it by specifying
which kinds of cross-origin requests it allows (by origin, headers, methods,
content type, etc.) through the Access-Control-* response headers.

Hence, the presence of these headers determines whether a resource is accessible.
*/

var (
	// SimpleRequestMethods holds the CORS-safelisted request methods: using one
	// of them does not by itself make a cross-origin request need a preflight.
	SimpleRequestMethods = newSet().Add(
		"GET",
		"HEAD",
		"POST",
	)

	// SimpleRequestHeaders holds the CORS-safelisted request header names:
	// sending only these does not by itself make a cross-origin request need a
	// preflight.
	SimpleRequestHeaders = newSet().Add(
		"Accept",
		"Accept-Language",
		"Content-Language",
		"Content-Type",
	)

	// SimpleRequestContentTypes holds the Content-Type media types a
	// cross-origin request may carry without needing a preflight.
	SimpleRequestContentTypes = newSet().Add(
		"application/x-www-form-urlencoded",
		"multipart/form-data",
		"text/plain",
	)

	// SimpleResponseHeaders holds the response header names that a user agent
	// lets scripts read on a CORS response without them being listed in
	// Access-Control-Expose-Headers.
	SimpleResponseHeaders = newSet().Add(
		"Cache-Control",
		"Content-Language",
		"Content-Type",
		"Expires",
		"Last-Modified",
		"Pragma",
	)
)

// PolicyStore keeps one CORS policy per API route and applies to incoming
// HTTP requests.
//
// CORS governs access to the resources served here by constraining
// cross-origin requests: which origins may ask, with which HTTP methods, which
// headers, and so on.
type PolicyStore struct {
	// Policies maps each route (presumably its path) to its own Policy, which
	// both the preflight handler and the route's handler can consult.
	Policies map[string]Policy
}

// New returns an empty PolicyStore, ready to hold the CORS policy of each
// CORS-enabled route.
//
// NOTE(review): s is currently unused; presumably the CORS-enabled routes are
// meant to be registered on it.
func New(s *http.ServeMux) *PolicyStore {
	return &PolicyStore{Policies: make(map[string]Policy)}
}

// New registers rules as the CORS policy of the route identified by path,
// replacing any policy previously registered for that path. The Policies map
// is created on first use, so a zero PolicyStore can be used directly.
func (p *PolicyStore) New(path string, rules Policy) {
	if p.Policies == nil {
		p.Policies = make(map[string]Policy)
	}
	p.Policies[path] = rules
}

// Policy defines how to respond to cross-origin requests for a given resource.
//
// AllowedOrigins, AllowedHeaders, AllowedContentTypes, ExposeHeaders and
// AllowedMethods are sets of strings. Elements are inserted with the variadic
// `Add(strls ...string)` method, which returns the set so calls can be chained,
// and looked up with `Contains(str string, caseSensitive bool)`. The element
// "*" acts as a wildcard meaning that any value (origin, header, method or
// content type) is accepted.
type Policy struct {
	AllowedOrigins      set           // origins allowed to make cross-origin requests
	AllowedHeaders      set           // request headers accepted in a preflight
	AllowedContentTypes set           // accepted Content-Type media types
	ExposeHeaders       set           // sent as Access-Control-Expose-Headers
	AllowedMethods      set           // request methods accepted in a preflight
	AllowCredentials    bool          // sent as Access-Control-Allow-Credentials
	MaxAge              time.Duration // preflight cache lifetime (presumably Access-Control-Max-Age)
}


// preflightHandler answers CORS preflight requests on behalf of the embedded
// PolicyStore; the store's fields and methods are promoted onto the handler.
type preflightHandler struct {
	*PolicyStore
}

// MaxAge sets how long the user agent may cache the result of a preflight
// request. ServeHTTP writes it, in seconds, as the Access-Control-Max-Age
// response header when it is non-zero.
//
// NOTE(review): the visible PolicyStore declares neither an AllowedHeaders nor
// a MaxAge field, so this does not compile as written; presumably PolicyStore
// was meant to carry (or embed) a Policy. Also, Access-Control-Max-Age is a
// response header that clients never send, so adding it to AllowedHeaders
// (the request headers accepted in a preflight) looks unnecessary — confirm.
func (p *preflightHandler) MaxAge(t time.Duration) {
	// Implementation which should set the Access-Control-Max-Age header in sec.
	// (in the allowed headers)
	p.PolicyStore.AllowedHeaders.Add("Access-Control-Max-Age")
	p.PolicyStore.MaxAge = t

}

// NewHandler builds a request handler that enforces a fresh CORS policy.
//
// The policy starts with no allowed origins, methods or exposed headers. It
// accepts the request headers Accept, Accept-Language, Content-Language,
// Content-Type and Origin, and the content types
// application/x-www-form-urlencoded, multipart/form-data and text/plain.
//
// NOTE(review): the policy is stored in h.Policy, while ServeHTTP and
// WithCredentials read h.Parameters; Handler is not declared in the visible
// source, so confirm which field is intended.
func NewHandler() Handler {
	policy := &Policy{
		AllowedOrigins:      newSet(),
		AllowedHeaders:      newSet().Add("Accept", "Accept-Language", "Content-Language", "Content-Type", "Origin"),
		AllowedContentTypes: newSet().Add("application/x-www-form-urlencoded", "multipart/form-data", "text/plain"),
		ExposeHeaders:       newSet(),
		AllowedMethods:      newSet(),
	}

	var h Handler
	h.Policy = policy
	return h
}


// ServeHTTP answers a CORS preflight request.
//
// It checks, in order, the Origin, Access-Control-Request-Method and
// Access-Control-Request-Headers request headers against p.Parameters. It
// replies 403 Forbidden on the first violation. On success it writes the
// Access-Control-Allow-Origin, -Allow-Credentials, -Allow-Methods,
// -Allow-Headers and (when configured) -Max-Age response headers that the user
// agent needs before it sends the actual request.
//
// NOTE(review): preflightHandler only embeds *PolicyStore, which declares no
// Parameters or MaxAge field in the visible source. The per-route policy
// lookup (e.g. via PolicyStore.Policies) still has to be wired in.
func (p *preflightHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Checking origin.
	origin := r.Header.Get("Origin")
	if origin == "" {
		http.Error(w, "origin header is absent", http.StatusForbidden)
		return
	}
	w.Header().Add("Vary", "Origin")

	origins := p.Parameters.AllowedOrigins
	if !origins.Contains(origin, true) && !origins.Contains("*", false) {
		http.Error(w, "origin not allowed", http.StatusForbidden)
		return
	}

	// Checking method.
	w.Header().Add("Vary", "Access-Control-Request-Method")
	method := r.Header.Get("Access-Control-Request-Method")
	if method == "" {
		http.Error(w, "method header absent", http.StatusForbidden)
		return
	}
	methods := p.Parameters.AllowedMethods
	if !methods.Contains(method, true) && !methods.Contains("*", true) {
		http.Error(w, "method not allowed", http.StatusForbidden)
		return
	}

	// Checking headers. Access-Control-Request-Headers carries a
	// comma-separated list of field names. User agents omit it when the actual
	// request uses only CORS-safelisted headers, so its absence is not an error.
	w.Header().Add("Vary", "Access-Control-Request-Headers")
	requested := splitHeaderList(r.Header.Values("Access-Control-Request-Headers"))
	allowedHeaders := p.Parameters.AllowedHeaders
	if !allowedHeaders.Contains("*", false) {
		for _, name := range requested {
			if !allowedHeaders.Contains(name, false) {
				http.Error(w, "unallowed headers present", http.StatusForbidden)
				return
			}
		}
	}

	// The user agent fails the preflight unless Access-Control-Allow-Origin
	// matches the request origin. Echoing the origin (rather than "*") keeps
	// the response valid when credentials are allowed.
	w.Header().Set("Access-Control-Allow-Origin", origin)
	setAllowCredentials(w, p.Parameters.AllowCredentials)

	if p.MaxAge != 0 {
		setMaxAge(w, int(p.MaxAge.Seconds()))
	}

	w.Header().Set("Access-Control-Allow-Methods", method)
	if len(requested) > 0 {
		w.Header().Set("Access-Control-Allow-Headers", strings.Join(requested, ", "))
	}
}

// splitHeaderList splits comma-separated header values (e.g. "X-A, x-b") into
// the individual, whitespace-trimmed, non-empty elements.
func splitHeaderList(values []string) []string {
	var names []string
	for _, v := range values {
		for _, name := range strings.Split(v, ",") {
			if name = strings.TrimSpace(name); name != "" {
				names = append(names, name)
			}
		}
	}
	return names
}

// WithCredentials returns h with its policy set to allow credentials (cookies,
// authorization headers, TLS client certificates) on cross-origin requests.
// ServeHTTP then sends Access-Control-Allow-Credentials: true.
//
// NOTE(review): h is a value receiver, but if h.Parameters is a pointer (as
// NewHandler's use of new(Policy) suggests), the copy shares it. In that case
// this also changes the policy of the handler it was called on; confirm this
// aliasing is intended.
func (h Handler) WithCredentials() Handler {
	h.Parameters.AllowCredentials = true
	return h
}

// ServeHTTP applies the handler's CORS policy to an actual (non-preflight)
// request and then delegates to the wrapped handler, if any.
//
// Requests without an Origin header are passed through untouched. For
// requests that carry one, the Access-Control-Allow-Origin,
// Access-Control-Allow-Credentials and Access-Control-Expose-Headers response
// headers are written before the wrapped handler runs. This applies to
// "simple" requests too: the simple/non-simple distinction only decides
// whether the user agent sends a preflight first, and the user agent still
// blocks a cross-origin response that lacks Access-Control-Allow-Origin.
func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// The response depends on Origin even when it is absent, so caches must not
	// reuse a same-origin response for a cross-origin request.
	w.Header().Add("Vary", "Origin")

	if originHeaderIsPresent(r) {
		setAllowOrigin(w, r, h.Parameters.AllowedOrigins)
		setAllowCredentials(w, h.Parameters.AllowCredentials)
		setExposeHeaders(w, h.Parameters.ExposeHeaders)
	}

	if h.next != nil {
		h.next.ServeHTTP(w, r)
	}
}


// setAllowOrigin writes the Access-Control-Allow-Origin response header when
// the request carries exactly one Origin value and allowed permits it, either
// because the origin is listed (case-sensitively) or because allowed contains
// the "*" wildcard. The request origin is echoed back rather than "*", which
// keeps the response valid when credentials are allowed.
//
// In every other case (no Origin, several Origin values, or a disallowed
// origin) the header is left unset, so the user agent blocks the response. It
// must not be set to "null" instead: sandboxed documents and some other
// contexts send the literal origin "null", so that value would grant them
// access.
func setAllowOrigin(w http.ResponseWriter, r *http.Request, allowed set) {
	origins := r.Header.Values("Origin")
	if len(origins) != 1 {
		return
	}

	origin := origins[0]
	if allowed.Contains(origin, true) || allowed.Contains("*", true) {
		w.Header().Set("Access-Control-Allow-Origin", origin)
	}
}

// setAllowMethods adds one Access-Control-Allow-Methods response header line
// per element of s. It belongs in a preflight response and tells the user
// agent which methods the actual request may use. Lines appear in map
// iteration order, which is unspecified.
func setAllowMethods(w http.ResponseWriter, s set) {
	hdr := w.Header()
	for name := range s {
		hdr.Add("Access-Control-Allow-Methods", name)
	}
}

// setAllowHeaders adds one Access-Control-Allow-Headers response header line
// per element of s. It belongs in a preflight response and tells the user
// agent which request headers the actual request may carry. Lines appear in
// map iteration order, which is unspecified.
func setAllowHeaders(w http.ResponseWriter, s set) {
	hdr := w.Header()
	for name := range s {
		hdr.Add("Access-Control-Allow-Headers", name)
	}
}

// setExposeHeaders adds one Access-Control-Expose-Headers response header line
// per element of s. This lists the response headers the user agent may let
// the requesting page read from a CORS response. Lines appear in map
// iteration order, which is unspecified.
func setExposeHeaders(w http.ResponseWriter, s set) {
	hdr := w.Header()
	for name := range s {
		hdr.Add("Access-Control-Expose-Headers", name)
	}
}

// setAllowCredentials writes Access-Control-Allow-Credentials: true when b is
// true. On a preflight response this tells the user agent that the actual
// request may include credentials (cookies, authorization headers, TLS client
// certificates). On an actual response it lets the page read a response to a
// credentialed request.
//
// When b is false the header is omitted entirely: "true" is its only valid
// value, and absence is how "credentials not allowed" is expressed.
func setAllowCredentials(w http.ResponseWriter, b bool) {
	if !b {
		return
	}
	w.Header().Set("Access-Control-Allow-Credentials", "true")
}

// setMaxAge sets the Access-Control-Max-Age response header to seconds,
// telling the user agent (a browser, for instance) how long it may cache the
// result of this preflight request. Any earlier value of the header is
// replaced.
func setMaxAge(w http.ResponseWriter, seconds int) {
	value := strconv.FormatInt(int64(seconds), 10)
	w.Header().Set("Access-Control-Max-Age", value)
}

// headersAreAllowed reports whether every header field name present on r is a
// member of s, looked up with caseSensitive set to false. A request without
// headers is considered allowed.
//
// NOTE(review): every header on the request is checked, including Origin and
// others the client adds itself, not only headers set by page scripts.
func headersAreAllowed(r *http.Request, s set) bool {
	allowed := true
	for name := range r.Header {
		if !s.Contains(name, false) {
			allowed = false
			break
		}
	}
	return allowed
}

// methodIsAllowed reports whether r's method is a member of s. The lookup is
// case-sensitive, matching HTTP's treatment of method names.
func methodIsAllowed(r *http.Request, s set) bool {
	method := r.Method
	return s.Contains(method, true)
}

// contentTypeIsAllowed reports whether every Content-Type value on r names a
// media type in s, looked up with caseSensitive set to false. Parameters such
// as "; charset=utf-8" are ignored. A request without a Content-Type header is
// allowed.
func contentTypeIsAllowed(r *http.Request, s set) bool {
	for _, value := range r.Header.Values("Content-Type") {
		// Compare only the media type: "text/plain; charset=utf-8" is still
		// text/plain.
		mediaType := value
		if i := strings.IndexByte(value, ';'); i >= 0 {
			mediaType = value[:i]
		}
		if !s.Contains(strings.TrimSpace(mediaType), false) {
			return false
		}
	}
	return true
}

// originHeaderIsPresent reports whether req carries a non-empty Origin header,
// i.e. whether it may be a cross-origin request.
func originHeaderIsPresent(req *http.Request) bool {
	return textproto.MIMEHeader(req.Header).Get("Origin") != ""
}

// set is an unordered collection of distinct strings. It offers three methods:
//   - Add, which inserts one or more elements (stored exactly as given),
//   - Remove, which deletes an element,
//   - Contains, which looks an element up.
//
// Remove and Contains take a caseSensitive flag. When it is false, comparison
// ignores case on both sides (strings.EqualFold), whatever case an element was
// stored in.
type set map[string]struct{}

// newSet returns an empty, ready-to-use set.
func newSet() set {
	return make(set)
}

// Add inserts each of strls into s and returns s so calls can be chained, as
// in newSet().Add("GET", "HEAD").
func (s set) Add(strls ...string) set {
	for _, str := range strls {
		s[str] = struct{}{}
	}
	return s
}

// Remove deletes str from s. When caseSensitive is false, every element equal
// to str under case folding is deleted, so Remove("ACCEPT", false) removes both
// "Accept" and "accept".
func (s set) Remove(str string, caseSensitive bool) {
	if caseSensitive {
		delete(s, str)
		return
	}
	// Deleting entries while ranging over a map is permitted in Go.
	for k := range s {
		if strings.EqualFold(k, str) {
			delete(s, k)
		}
	}
}

// Contains reports whether str is an element of s. When caseSensitive is false
// the lookup ignores case, so an element stored as "Content-Type" is found by
// Contains("content-type", false) and vice versa.
func (s set) Contains(str string, caseSensitive bool) bool {
	// An exact match satisfies both modes and is a constant-time map lookup.
	if _, ok := s[str]; ok {
		return true
	}
	if caseSensitive {
		return false
	}
	for k := range s {
		if strings.EqualFold(k, str) {
			return true
		}
	}
	return false
}

from xhttp.

Related Issues (1)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games, to make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.