Merge branch 'main' of ssh://git.bullercodeworks.com:2200/brian/expds
This commit is contained in:
@@ -27,7 +27,7 @@ func (a *AppLogHandler) addGroups(groups ...string) { a.groups = append(a.group
|
||||
// AppLogHandler can handle all levels
|
||||
func (a *AppLogHandler) Enabled(_ context.Context, lvl slog.Level) bool { return lvl >= a.level }
|
||||
|
||||
func (a *AppLogHandler) Handle(ctx context.Context, rcd slog.Record) error {
|
||||
func (a *AppLogHandler) Handle(_ context.Context, rcd slog.Record) error {
|
||||
if a.logFunc == nil {
|
||||
return errors.New("no log func defined")
|
||||
}
|
||||
|
||||
28
data/repo.go
28
data/repo.go
@@ -40,6 +40,7 @@ type Repo struct {
|
||||
|
||||
handler *AppLogHandler
|
||||
logFunc func(string, ...any)
|
||||
Logger *slog.Logger
|
||||
|
||||
context context.Context
|
||||
}
|
||||
@@ -56,6 +57,7 @@ func NewRepo() (*Repo, error) {
|
||||
} else {
|
||||
r.handler.SetLevel(slog.LevelWarn)
|
||||
}
|
||||
r.Logger = slog.Default()
|
||||
var err error
|
||||
r.Auth, err = NewAuthRepo(r)
|
||||
if err != nil {
|
||||
@@ -64,10 +66,7 @@ func NewRepo() (*Repo, error) {
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *Repo) GetPDS(atId string) (*models.Pds, error) {
|
||||
if p, ok := r.LoadedPDSs[atId]; ok && time.Since(p.RefreshTime) < r.BestBy {
|
||||
return p, nil
|
||||
}
|
||||
func (r *Repo) fetchPds(atId string) (*models.Pds, error) {
|
||||
p, err := models.NewPdsFromDid(atId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -76,8 +75,19 @@ func (r *Repo) GetPDS(atId string) (*models.Pds, error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (r *Repo) SendToPDS() error {
|
||||
session, err := r.Auth.GetSession()
|
||||
func (r *Repo) ReloadPds(atId string) (*models.Pds, error) {
|
||||
return r.fetchPds(atId)
|
||||
}
|
||||
|
||||
func (r *Repo) GetPDS(atId string) (*models.Pds, error) {
|
||||
if p, ok := r.LoadedPDSs[atId]; ok && time.Since(p.RefreshTime) < r.BestBy {
|
||||
return p, nil
|
||||
}
|
||||
return r.fetchPds(atId)
|
||||
}
|
||||
|
||||
func (r *Repo) SendToPDS(did syntax.DID) error {
|
||||
session, err := r.Auth.GetSession(did)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -94,11 +104,11 @@ func (r *Repo) SendToPDS() error {
|
||||
var resp struct {
|
||||
Uri syntax.ATURI `json:"uri"`
|
||||
}
|
||||
slog.Debug("posting expds status...")
|
||||
r.Logger.Debug("posting expds status...")
|
||||
if err := c.Post(r.context, "com.atproto.repo.CreateRecord", body, &resp); err != nil {
|
||||
return err
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("posted: %s :: %s", resp.Uri.Authority(), resp.Uri.RecordKey()))
|
||||
r.Logger.Debug(fmt.Sprintf("posted: %s :: %s", resp.Uri.Authority(), resp.Uri.RecordKey()))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -106,4 +116,6 @@ func (r *Repo) SetLogFunc(l func(string, ...any)) {
|
||||
r.logFunc = l
|
||||
r.handler = NewAppLogHandler(r.logFunc)
|
||||
slog.SetDefault(slog.New(r.handler))
|
||||
r.Logger = slog.Default()
|
||||
r.Logger.Debug("New Log Func Set for slog")
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/bluesky-social/indigo/atproto/auth/oauth"
|
||||
"github.com/bluesky-social/indigo/atproto/syntax"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
@@ -25,12 +27,15 @@ type AuthRepo struct {
|
||||
|
||||
session *oauth.ClientSessionData
|
||||
authError error
|
||||
|
||||
Logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewAuthRepo(r *Repo) (*AuthRepo, error) {
|
||||
a := &AuthRepo{
|
||||
r: r,
|
||||
context: r.context,
|
||||
Logger: r.Logger,
|
||||
}
|
||||
var err error
|
||||
a.oauthConfig, a.oauthClient, a.store, err = a.buildOAuthClient()
|
||||
@@ -109,20 +114,22 @@ func (r *AuthRepo) ListenForCallback(res chan url.Values) (int, error) {
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
panic(err)
|
||||
}
|
||||
r.Logger.Debug("Server Shut Down")
|
||||
}()
|
||||
|
||||
return listener.Addr().(*net.TCPAddr).Port, nil
|
||||
}
|
||||
|
||||
func (r *AuthRepo) HasAuth() bool {
|
||||
sess, err := r.store.GetMostRecentSession(r.context)
|
||||
return err != nil || sess == nil
|
||||
func (r *AuthRepo) HasAuth(did syntax.DID) bool {
|
||||
sess, err := r.GetSession(did)
|
||||
return err == nil && sess != nil
|
||||
}
|
||||
|
||||
func (r *AuthRepo) GetSession() (*oauth.ClientSession, error) {
|
||||
sess, err := r.store.GetMostRecentSession(r.context)
|
||||
func (r *AuthRepo) GetSession(did syntax.DID) (*oauth.ClientSession, error) {
|
||||
sess, err := r.store.GetMostRecentSessionFor(r.context, did)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting most recent session: %w", err)
|
||||
}
|
||||
r.Logger.Warn(fmt.Sprintf("GetSession(): Resuming Session: %s (%s)", sess.SessionID, sess.AccountDID))
|
||||
return r.oauthClient.ResumeSession(r.context, sess.AccountDID, sess.SessionID)
|
||||
}
|
||||
|
||||
139
data/repo_auth.go.orig
Normal file
139
data/repo_auth.go.orig
Normal file
@@ -0,0 +1,139 @@
|
||||
package data
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bluesky-social/indigo/atproto/auth/oauth"
|
||||
"github.com/bluesky-social/indigo/atproto/syntax"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type AuthRepo struct {
|
||||
r *Repo
|
||||
oauthClient *oauth.ClientApp
|
||||
oauthConfig *oauth.ClientConfig
|
||||
store *SqliteStore
|
||||
context context.Context
|
||||
|
||||
<<<<<<< HEAD
|
||||
session *oauth.ClientSessionData
|
||||
authError error
|
||||
=======
|
||||
session *oauth.ClientSessionData
|
||||
|
||||
Logger *slog.Logger
|
||||
>>>>>>> 367d62ff009186e5aa584fd069dd57aaf7d46a8a
|
||||
}
|
||||
|
||||
func NewAuthRepo(r *Repo) (*AuthRepo, error) {
|
||||
a := &AuthRepo{
|
||||
r: r,
|
||||
context: r.context,
|
||||
Logger: r.Logger,
|
||||
}
|
||||
var err error
|
||||
a.oauthConfig, a.oauthClient, a.store, err = a.buildOAuthClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// Build the OAuthClient connected to our sqlite db
|
||||
func (r *AuthRepo) buildOAuthClient() (*oauth.ClientConfig, *oauth.ClientApp, *SqliteStore, error) {
|
||||
config := oauth.ClientConfig{
|
||||
ClientID: "https://expds.bullercodeworks.com/oauth-client-metadata.json",
|
||||
Scopes: []string{"atproto", "repo:*", "blob:*/*"},
|
||||
UserAgent: "expds",
|
||||
}
|
||||
|
||||
store, err := NewSqliteStore(&SqliteStoreConfig{
|
||||
DatabasePath: r.prepareDbPath(),
|
||||
SessionExpiryDuration: time.Hour * 24 * 90,
|
||||
SessionInactivityDuration: time.Hour * 24 * 14,
|
||||
AuthRequestExpiryDuration: time.Minute * 30,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
oauthClient := oauth.NewClientApp(&config, store)
|
||||
return &config, oauthClient, store, nil
|
||||
}
|
||||
|
||||
func (r *AuthRepo) StartAuthFlow(port int, identifier string, callbackRes chan url.Values) (string, error) {
|
||||
r.oauthConfig.CallbackURL = fmt.Sprintf("http://127.0.0.1:%d/callback", port)
|
||||
authUrl, err := r.oauthClient.StartAuthFlow(r.context, identifier)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error logging in: %w", err)
|
||||
}
|
||||
if !strings.HasPrefix(authUrl, "https://") {
|
||||
return "", fmt.Errorf("non-https authUrl")
|
||||
}
|
||||
go func() {
|
||||
exec.Command("xdg-open", authUrl).Start()
|
||||
r.session, r.authError = r.oauthClient.ProcessCallback(r.context, <-callbackRes)
|
||||
}()
|
||||
return authUrl, nil
|
||||
}
|
||||
|
||||
// Follows XDG conventions and creates the directories if necessary.
|
||||
// By default, on linux, this will be "~/.local/share/go-oauth-cli-app/oauth_sessions.sqlite3"
|
||||
func (r *AuthRepo) prepareDbPath() string {
|
||||
return filepath.Join(viper.GetString(KeyDataDir), "expds.sqlite3")
|
||||
}
|
||||
|
||||
// HTTP Server listening for OAuth Response
|
||||
func (r *AuthRepo) ListenForCallback(res chan url.Values) (int, error) {
|
||||
listener, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
server := &http.Server{
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
mux.HandleFunc("/callback", func(w http.ResponseWriter, req *http.Request) {
|
||||
res <- req.URL.Query()
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("<!DOCTYPE html><html><body><h2>expds</h2><p>You can safely close this window and return to your application.</p></body></html>\n"))
|
||||
go server.Shutdown(r.context)
|
||||
})
|
||||
|
||||
go func() {
|
||||
err := server.Serve(listener)
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
panic(err)
|
||||
}
|
||||
r.Logger.Debug("Server Shut Down")
|
||||
}()
|
||||
|
||||
return listener.Addr().(*net.TCPAddr).Port, nil
|
||||
}
|
||||
|
||||
func (r *AuthRepo) HasAuth(did syntax.DID) bool {
|
||||
sess, err := r.GetSession(did)
|
||||
return err == nil && sess != nil
|
||||
}
|
||||
|
||||
func (r *AuthRepo) GetSession(did syntax.DID) (*oauth.ClientSession, error) {
|
||||
sess, err := r.store.GetMostRecentSessionFor(r.context, did)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting most recent session: %w", err)
|
||||
}
|
||||
r.Logger.Warn(fmt.Sprintf("GetSession(): Resuming Session: %s (%s)", sess.SessionID, sess.AccountDID))
|
||||
return r.oauthClient.ResumeSession(r.context, sess.AccountDID, sess.SessionID)
|
||||
}
|
||||
@@ -96,6 +96,20 @@ func (m *SqliteStore) GetSession(ctx context.Context, did syntax.DID, sessionID
|
||||
return &row.Data, nil
|
||||
}
|
||||
|
||||
func (m *SqliteStore) GetMostRecentSessionFor(ctx context.Context, did syntax.DID) (*oauth.ClientSessionData, error) {
|
||||
var row storedSessionData
|
||||
res := m.db.WithContext(ctx).Where(&storedSessionData{
|
||||
AccountDid: did,
|
||||
}).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Name: "updated_at"},
|
||||
Desc: true,
|
||||
}).First(&row)
|
||||
if res.Error != nil {
|
||||
return nil, res.Error
|
||||
}
|
||||
return &row.Data, nil
|
||||
}
|
||||
|
||||
// not part of the ClientAuthStore interface, just used for the CLI app
|
||||
func (m *SqliteStore) GetMostRecentSession(ctx context.Context) (*oauth.ClientSessionData, error) {
|
||||
var row storedSessionData
|
||||
|
||||
Reference in New Issue
Block a user