packer-cn/packer/plugin-getter/github/getter.go

233 lines
5.9 KiB
Go

package github
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/google/go-github/v33/github"
plugingetter "github.com/hashicorp/packer/packer/plugin-getter"
"golang.org/x/oauth2"
)
const (
ghTokenAccessor = "PACKER_GITHUB_API_TOKEN"
defaultUserAgent = "packer-plugin-getter"
defaultHostname = "github.com"
)
type Getter struct {
Client *github.Client
UserAgent string
}
var _ plugingetter.Getter = &Getter{}
func tranformChecksumStream() func(in io.ReadCloser) (io.ReadCloser, error) {
return func(in io.ReadCloser) (io.ReadCloser, error) {
defer in.Close()
rd := bufio.NewReader(in)
buffer := bytes.NewBufferString("[")
json := json.NewEncoder(buffer)
for i := 0; ; i++ {
line, err := rd.ReadString('\n')
if err != nil {
if err != io.EOF {
return nil, fmt.Errorf(
"Error reading checksum file: %s", err)
}
break
}
parts := strings.Fields(line)
switch len(parts) {
case 2: // nominal case
checksumString, checksumFilename := parts[0], parts[1]
if i > 0 {
_, _ = buffer.WriteString(",")
}
if err := json.Encode(struct {
Checksum string `json:"checksum"`
Filename string `json:"filename"`
}{
Checksum: checksumString,
Filename: checksumFilename,
}); err != nil {
return nil, err
}
}
}
_, _ = buffer.WriteString("]")
return ioutil.NopCloser(buffer), nil
}
}
// transformVersionStream get a stream from github tags and transforms it into
// something Packer wants, namely a json list of Release.
func transformVersionStream(in io.ReadCloser) (io.ReadCloser, error) {
if in == nil {
return nil, fmt.Errorf("transformVersionStream got nil body")
}
defer in.Close()
dec := json.NewDecoder(in)
m := []struct {
Ref string `json:"ref"`
}{}
if err := dec.Decode(&m); err != nil {
return nil, err
}
out := []plugingetter.Release{}
for _, m := range m {
out = append(out, plugingetter.Release{
Version: strings.TrimPrefix(m.Ref, "refs/tags/"),
})
}
buf := &bytes.Buffer{}
if err := json.NewEncoder(buf).Encode(out); err != nil {
return nil, err
}
return ioutil.NopCloser(buf), nil
}
// HostSpecificTokenAuthTransport makes sure the http roundtripper only sets an
// auth token for requests aimed at a specific host.
//
// This helps for example to get release files from Github as Github will
// redirect to s3 which will error if we give it a Github auth token.
type HostSpecificTokenAuthTransport struct {
// Host to TokenSource map
TokenSources map[string]oauth2.TokenSource
// actual RoundTripper, nil means we use the default one from http.
Base http.RoundTripper
}
// RoundTrip authorizes and authenticates the request with an
// access token from Transport's Source.
func (t *HostSpecificTokenAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
source, found := t.TokenSources[req.Host]
if found {
reqBodyClosed := false
if req.Body != nil {
defer func() {
if !reqBodyClosed {
req.Body.Close()
}
}()
}
if source == nil {
return nil, errors.New("transport's Source is nil")
}
token, err := source.Token()
if err != nil {
return nil, err
}
token.SetAuthHeader(req)
// req.Body is assumed to be closed by the base RoundTripper.
reqBodyClosed = true
}
return t.base().RoundTrip(req)
}
func (t *HostSpecificTokenAuthTransport) base() http.RoundTripper {
if t.Base != nil {
return t.Base
}
return http.DefaultTransport
}
func (g *Getter) Get(what string, opts plugingetter.GetOptions) (io.ReadCloser, error) {
if opts.PluginRequirement.Identifier.Hostname != defaultHostname {
s := opts.PluginRequirement.Identifier.String() + " doesn't appear to be a valid " + defaultHostname + " source address; check source and try again."
return nil, errors.New(s)
}
ctx := context.TODO()
if g.Client == nil {
var tc *http.Client
if tk := os.Getenv(ghTokenAccessor); tk != "" {
log.Printf("[DEBUG] github-getter: using %s", ghTokenAccessor)
ts := oauth2.StaticTokenSource(
&oauth2.Token{AccessToken: tk},
)
tc = &http.Client{
Transport: &HostSpecificTokenAuthTransport{
TokenSources: map[string]oauth2.TokenSource{
"api.github.com": ts,
},
},
}
}
g.Client = github.NewClient(tc)
g.Client.UserAgent = defaultUserAgent
if g.UserAgent != "" {
g.Client.UserAgent = g.UserAgent
}
}
var req *http.Request
var err error
transform := func(in io.ReadCloser) (io.ReadCloser, error) {
return in, nil
}
switch what {
case "releases":
u := filepath.ToSlash("/repos/" + opts.PluginRequirement.Identifier.RealRelativePath() + "/git/matching-refs/tags")
req, err = g.Client.NewRequest("GET", u, nil)
transform = transformVersionStream
case "sha256":
// something like https://github.com/sylviamoss/packer-plugin-comment/releases/download/v0.2.11/packer-plugin-comment_v0.2.11_x5_SHA256SUMS
u := filepath.ToSlash("https://github.com/" + opts.PluginRequirement.Identifier.RealRelativePath() + "/releases/download/" + opts.Version() + "/" + opts.PluginRequirement.FilenamePrefix() + opts.Version() + "_SHA256SUMS")
req, err = g.Client.NewRequest(
"GET",
u,
nil,
)
transform = tranformChecksumStream()
case "zip":
u := filepath.ToSlash("https://github.com/" + opts.PluginRequirement.Identifier.RealRelativePath() + "/releases/download/" + opts.Version() + "/" + opts.ExpectedZipFilename())
req, err = g.Client.NewRequest(
"GET",
u,
nil,
)
default:
return nil, fmt.Errorf("%q not implemented", what)
}
if err != nil {
return nil, err
}
log.Printf("[DEBUG] github-getter: getting %q", req.URL)
resp, err := g.Client.BareDo(ctx, req)
if err != nil {
// here BareDo will return an err if the request failed or if the
// status is not considered a valid http status.
if resp != nil {
resp.Body.Close()
}
return nil, err
}
return transform(resp.Body)
}