136 lines
3.7 KiB
Go
136 lines
3.7 KiB
Go
package data
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/bluesky-social/indigo/atproto/auth/oauth"
|
|
"github.com/bluesky-social/indigo/atproto/syntax"
|
|
"github.com/spf13/viper"
|
|
)
|
|
|
|
// AuthRepo handles atproto OAuth authentication. It owns the OAuth client and
// its configuration, plus the sqlite-backed session store. It also records
// the outcome of the interactive login started by StartAuthFlow.
type AuthRepo struct {
	// r is the parent Repo this AuthRepo was created from.
	r *Repo
	// oauthClient runs the OAuth flows and resumes stored sessions.
	oauthClient *oauth.ClientApp
	// oauthConfig is the configuration oauthClient was built with.
	// StartAuthFlow sets its CallbackURL before each login.
	oauthConfig *oauth.ClientConfig
	// store is the sqlite-backed store handed to oauth.NewClientApp. It is
	// also queried directly for the most recent session of an account.
	store *SqliteStore
	// context is copied from the parent Repo. It is passed to the OAuth
	// client, the store, and the callback server's Shutdown.
	context context.Context

	// session and authError hold the result of the ProcessCallback call made
	// by StartAuthFlow's background goroutine.
	// NOTE(review): they are written from that goroutine without any
	// synchronization — confirm readers only access them after the flow ends.
	session   *oauth.ClientSessionData
	authError error

	// Logger is shared with the parent Repo.
	Logger *slog.Logger
}
|
|
|
|
func NewAuthRepo(r *Repo) (*AuthRepo, error) {
|
|
a := &AuthRepo{
|
|
r: r,
|
|
context: r.context,
|
|
Logger: r.Logger,
|
|
}
|
|
var err error
|
|
a.oauthConfig, a.oauthClient, a.store, err = a.buildOAuthClient()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return a, nil
|
|
}
|
|
|
|
// Build the OAuthClient connected to our sqlite db
|
|
func (r *AuthRepo) buildOAuthClient() (*oauth.ClientConfig, *oauth.ClientApp, *SqliteStore, error) {
|
|
config := oauth.ClientConfig{
|
|
ClientID: "https://expds.bullercodeworks.com/oauth-client-metadata.json",
|
|
Scopes: []string{"atproto", "repo:*", "blob:*/*"},
|
|
UserAgent: "expds",
|
|
}
|
|
|
|
store, err := NewSqliteStore(&SqliteStoreConfig{
|
|
DatabasePath: r.prepareDbPath(),
|
|
SessionExpiryDuration: time.Hour * 24 * 90,
|
|
SessionInactivityDuration: time.Hour * 24 * 14,
|
|
AuthRequestExpiryDuration: time.Minute * 30,
|
|
})
|
|
if err != nil {
|
|
return nil, nil, nil, err
|
|
}
|
|
|
|
oauthClient := oauth.NewClientApp(&config, store)
|
|
return &config, oauthClient, store, nil
|
|
}
|
|
|
|
func (r *AuthRepo) StartAuthFlow(port int, identifier string, callbackRes chan url.Values) (string, error) {
|
|
r.oauthConfig.CallbackURL = fmt.Sprintf("http://127.0.0.1:%d/callback", port)
|
|
authUrl, err := r.oauthClient.StartAuthFlow(r.context, identifier)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error logging in: %w", err)
|
|
}
|
|
if !strings.HasPrefix(authUrl, "https://") {
|
|
return "", fmt.Errorf("non-https authUrl")
|
|
}
|
|
go func() {
|
|
exec.Command("xdg-open", authUrl).Start()
|
|
r.session, r.authError = r.oauthClient.ProcessCallback(r.context, <-callbackRes)
|
|
}()
|
|
return authUrl, nil
|
|
}
|
|
|
|
// Follows XDG conventions and creates the directories if necessary.
|
|
// By default, on linux, this will be "~/.local/share/go-oauth-cli-app/oauth_sessions.sqlite3"
|
|
func (r *AuthRepo) prepareDbPath() string {
|
|
return filepath.Join(viper.GetString(KeyDataDir), "expds.sqlite3")
|
|
}
|
|
|
|
// HTTP Server listening for OAuth Response
|
|
func (r *AuthRepo) ListenForCallback(res chan url.Values) (int, error) {
|
|
listener, err := net.Listen("tcp", ":0")
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
mux := http.NewServeMux()
|
|
server := &http.Server{
|
|
Handler: mux,
|
|
}
|
|
|
|
mux.HandleFunc("/callback", func(w http.ResponseWriter, req *http.Request) {
|
|
res <- req.URL.Query()
|
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
|
w.WriteHeader(200)
|
|
w.Write([]byte("<!DOCTYPE html><html><body><h2>expds</h2><p>You can safely close this window and return to your application.</p></body></html>\n"))
|
|
go server.Shutdown(r.context)
|
|
})
|
|
|
|
go func() {
|
|
err := server.Serve(listener)
|
|
if !errors.Is(err, http.ErrServerClosed) {
|
|
panic(err)
|
|
}
|
|
r.Logger.Debug("Server Shut Down")
|
|
}()
|
|
|
|
return listener.Addr().(*net.TCPAddr).Port, nil
|
|
}
|
|
|
|
func (r *AuthRepo) HasAuth(did syntax.DID) bool {
|
|
sess, err := r.GetSession(did)
|
|
return err == nil && sess != nil
|
|
}
|
|
|
|
func (r *AuthRepo) GetSession(did syntax.DID) (*oauth.ClientSession, error) {
|
|
sess, err := r.store.GetMostRecentSessionFor(r.context, did)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting most recent session: %w", err)
|
|
}
|
|
r.Logger.Warn(fmt.Sprintf("GetSession(): Resuming Session: %s (%s)", sess.SessionID, sess.AccountDID))
|
|
return r.oauthClient.ResumeSession(r.context, sess.AccountDID, sess.SessionID)
|
|
}
|