diff --git a/cache/cache.go b/cache/cache.go index f43aeaf..57d4304 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -2,9 +2,9 @@ package cache import ( "fmt" + "log/slog" "time" - "github.com/apex/log" "github.com/sparetimecoders/goamqp" "gitlab.com/unboundsoftware/schemas/domain" @@ -18,7 +18,7 @@ type Cache struct { services map[string]map[string]map[string]struct{} subGraphs map[string]string lastUpdate map[string]string - logger log.Interface + logger *slog.Logger } func (c *Cache) OrganizationByAPIKey(apiKey string) *domain.Organization { @@ -98,7 +98,7 @@ func (c *Cache) Update(msg any, _ goamqp.Headers) (any, error) { case *domain.SubGraph: c.updateSubGraph(m.OrganizationId, m.Ref, m.ID.String(), m.Service, m.ChangedAt) default: - c.logger.Warnf("unexpected message received: %+v", msg) + c.logger.With("msg", msg).Warn("unexpected message received") } return nil, nil } @@ -124,7 +124,7 @@ func (c *Cache) addUser(sub string, organization domain.Organization) { } } -func New(logger log.Interface) *Cache { +func New(logger *slog.Logger) *Cache { return &Cache{ organizations: make(map[string]domain.Organization), users: make(map[string][]string), diff --git a/cmd/service/service.go b/cmd/service/service.go index b7169ff..79aa232 100644 --- a/cmd/service/service.go +++ b/cmd/service/service.go @@ -2,7 +2,9 @@ package main import ( "context" + "errors" "fmt" + "log/slog" "net/http" "os" "os/signal" @@ -17,8 +19,6 @@ import ( "github.com/99designs/gqlgen/graphql/handler/transport" "github.com/99designs/gqlgen/graphql/playground" "github.com/alecthomas/kong" - "github.com/apex/log" - "github.com/apex/log/handlers/json" "github.com/getsentry/sentry-go" sentryhttp "github.com/getsentry/sentry-go/http" "github.com/rs/cors" @@ -32,6 +32,7 @@ import ( "gitlab.com/unboundsoftware/schemas/domain" "gitlab.com/unboundsoftware/schemas/graph" "gitlab.com/unboundsoftware/schemas/graph/generated" + "gitlab.com/unboundsoftware/schemas/logging" "gitlab.com/unboundsoftware/schemas/middleware" "gitlab.com/unboundsoftware/schemas/store" ) @@ -59,9 +60,7 @@ const serviceName = "schemas" func main() { var cli CLI _ = kong.Parse(&cli) - log.SetHandler(json.New(os.Stdout)) - log.SetLevelFromString(cli.LogLevel) - logger := log.WithField("service", serviceName) + logger := logging.SetupLogger(cli.LogLevel, serviceName, buildVersion) closeEvents := make(chan error) if err := start( @@ -70,11 +69,11 @@ func main() { ConnectAMQP, cli, ); err != nil { - logger.WithError(err).Error("process error") + logger.With("error", err).Error("process error") } } -func start(closeEvents chan error, logger *log.Entry, connectToAmqpFunc func(url string) (Connection, error), cli CLI) error { +func start(closeEvents chan error, logger *slog.Logger, connectToAmqpFunc func(url string) (Connection, error), cli CLI) error { if err := setupSentry(logger, cli.SentryConfig); err != nil { return err } @@ -123,7 +122,7 @@ func start(closeEvents chan error, logger *log.Entry, connectToAmqpFunc func(url return fmt.Errorf("caching subgraphs: %w", err) } setups := []goamqp.Setup{ - goamqp.UseLogger(logger.Error), + goamqp.UseLogger(func(s string) { logger.Error(s) }), goamqp.CloseListener(closeEvents), goamqp.WithPrefetchLimit(20), goamqp.EventStreamPublisher(publisher), @@ -169,7 +168,7 @@ func start(closeEvents chan error, logger *log.Entry, connectToAmqpFunc func(url defer wg.Done() err := <-closeEvents if err != nil { - logger.WithError(err).Error("received close from AMQP") + logger.With("error", err).Error("received close from AMQP") rootCancel() } }() @@ -179,8 +178,11 @@ func start(closeEvents chan error, logger *log.Entry, connectToAmqpFunc func(url defer wg.Done() <-rootCtx.Done() - if err := httpSrv.Close(); err != nil { - logger.WithError(err).Error("close http server") + shutdownCtx, shutdownRelease := context.WithTimeout(context.Background(), 10*time.Second) + defer shutdownRelease() + + if err := httpSrv.Shutdown(shutdownCtx); err != nil { + logger.With("error", err).Error("close http server") } close(sigint) close(closeEvents) @@ -235,10 +237,10 @@ func start(closeEvents chan error, logger *log.Entry, connectToAmqpFunc func(url ), )) - logger.Infof("connect to http://localhost:%d/ for GraphQL playground", cli.Port) + logger.Info(fmt.Sprintf("connect to http://localhost:%d/ for GraphQL playground", cli.Port)) - if err := httpSrv.ListenAndServe(); err != nil { - logger.WithError(err).Error("listen http") + if err := httpSrv.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + logger.With("error", err).Error("listen http") } }() @@ -287,7 +289,7 @@ func healthFunc(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("OK")) } -func setupSentry(logger log.Interface, args SentryConfig) error { +func setupSentry(logger *slog.Logger, args SentryConfig) error { if args.Environment == "" { return fmt.Errorf("no Sentry environment supplied, exiting") } @@ -315,7 +317,7 @@ func setupSentry(logger log.Interface, args SentryConfig) error { if err := sentry.Init(cfg); err != nil { return fmt.Errorf("sentry setup: %w", err) } - logger.Infof("configured Sentry for env: %s", args.Environment) + logger.With("environment", args.Environment).Info("configured Sentry") return nil } diff --git a/graph/resolver.go b/graph/resolver.go index a27d195..5dc5e59 100644 --- a/graph/resolver.go +++ b/graph/resolver.go @@ -3,8 +3,8 @@ package graph import ( "context" "fmt" + "log/slog" - "github.com/apex/log" "gitlab.com/unboundsoftware/eventsourced/eventsourced" "gitlab.com/unboundsoftware/schemas/cache" @@ -26,7 +26,7 @@ type Publisher interface { type Resolver struct { EventStore eventsourced.EventStore Publisher Publisher - Logger log.Interface + Logger *slog.Logger Cache *cache.Cache } diff --git a/logging/log.go b/logging/log.go new file mode 100644 index 0000000..e102615 --- /dev/null +++ b/logging/log.go @@ -0,0 +1,52 @@ +package logging + +import ( + "context" + "log/slog" + "os" +) + +type Logger interface { + Info(msg string, args ...any) + Warn(msg string, args ...any) + Error(msg string, args ...any) +} + +var defaultLogger *slog.Logger + +type contextKey string + +const loggerKey = contextKey("logger") + +func SetupLogger(logLevel, serviceName, buildVersion string) *slog.Logger { + var leveler slog.LevelVar + + err := leveler.UnmarshalText([]byte(logLevel)) + + defaultLogger = slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ + AddSource: false, + Level: leveler.Level(), + ReplaceAttr: nil, + })).With("service", serviceName).With("version", buildVersion) + + if err != nil { + defaultLogger.With("err", err).Error("Failed to parse log level") + os.Exit(1) + } + slog.SetDefault(defaultLogger) + return defaultLogger +} + +// ContextWithLogger returns a new Context with the logger attached +func ContextWithLogger(ctx context.Context, logger *slog.Logger) context.Context { + return context.WithValue(ctx, loggerKey, logger) +} + +// LoggerFromContext returns a logger from the passed context or the default logger +func LoggerFromContext(ctx context.Context) *slog.Logger { + logger := ctx.Value(loggerKey) + if l, ok := logger.(*slog.Logger); ok { + return l + } + return defaultLogger +} diff --git a/logging/mocklogger.go b/logging/mocklogger.go new file mode 100644 index 0000000..37bf6cc --- /dev/null +++ b/logging/mocklogger.go @@ -0,0 +1,48 @@ +package logging + +import ( + "bytes" + "log/slog" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func NewMockLogger() *MockLogger { + logged := &bytes.Buffer{} + + return &MockLogger{ + logged: logged, + logger: slog.New(slog.NewTextHandler(logged, &slog.HandlerOptions{ + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + if a.Key == "time" { + return slog.Attr{} + } + return a + }, + })), + } +} + +type MockLogger struct { + logger *slog.Logger + logged *bytes.Buffer +} + +func (m *MockLogger) Logger() *slog.Logger { + return m.logger +} + +func (m *MockLogger) Check(t testing.TB, wantLogged []string) { + var gotLogged []string + if m.logged.String() != "" { + gotLogged = strings.Split(m.logged.String(), "\n") + gotLogged = gotLogged[:len(gotLogged)-1] + } + if len(wantLogged) == 0 { + assert.Empty(t, gotLogged) + return + } + assert.Equal(t, wantLogged, gotLogged) +}