Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions .github/gallery-agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"slices"
"strings"

"github.com/ghodss/yaml"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
cogito "github.com/mudler/cogito"

Expand Down Expand Up @@ -52,6 +53,11 @@ func cleanTextContent(text string) string {
return stripThinkingTags(strings.TrimRight(result, "\n"))
}

type galleryModel struct {
Name string `yaml:"name"`
Urls []string `yaml:"urls"`
}

// isModelExisting checks if a specific model ID exists in the gallery using text search
func isModelExisting(modelID string) (bool, error) {
indexPath := getGalleryIndexPath()
Expand All @@ -60,9 +66,20 @@ func isModelExisting(modelID string) (bool, error) {
return false, fmt.Errorf("failed to read %s: %w", indexPath, err)
}

contentStr := string(content)
// Simple text search - if the model ID appears anywhere in the file, it exists
return strings.Contains(contentStr, modelID), nil
var galleryModels []galleryModel

err = yaml.Unmarshal(content, &galleryModels)
if err != nil {
return false, fmt.Errorf("failed to unmarshal %s: %w", indexPath, err)
}

for _, galleryModel := range galleryModels {
if slices.Contains(galleryModel.Urls, modelID) {
return true, nil
}
}

return false, nil
}

// filterExistingModels removes models that already exist in the gallery
Expand Down
11 changes: 11 additions & 0 deletions core/cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ type RunCMD struct {
DisableGalleryEndpoint bool `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"`
MachineTag string `env:"LOCALAI_MACHINE_TAG,MACHINE_TAG" help:"Add Machine-Tag header to each response which is useful to track the machine in the P2P network" group:"api"`
LoadToMemory []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"`
EnableTracing bool `env:"LOCALAI_ENABLE_TRACING,ENABLE_TRACING" help:"Enable API tracing" group:"api"`
TracingMaxItems int `env:"LOCALAI_TRACING_MAX_ITEMS" default:"1024" help:"Maximum number of traces to keep" group:"api"`
AgentJobRetentionDays int `env:"LOCALAI_AGENT_JOB_RETENTION_DAYS,AGENT_JOB_RETENTION_DAYS" default:"30" help:"Number of days to keep agent job history (default: 30)" group:"api"`

Version bool
Expand Down Expand Up @@ -152,6 +154,15 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
opts = append(opts, config.DisableRuntimeSettings)
}

if r.EnableTracing {
opts = append(opts, config.EnableTracing)
}

if r.EnableTracing {
opts = append(opts, config.EnableTracing)
}
opts = append(opts, config.WithTracingMaxItems(r.TracingMaxItems))

token := ""
if r.Peer2Peer || r.Peer2PeerToken != "" {
xlog.Info("P2P mode enabled")
Expand Down
23 changes: 23 additions & 0 deletions core/config/application_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ type ApplicationConfig struct {
UploadLimitMB, Threads, ContextSize int
F16 bool
Debug bool
EnableTracing bool
TracingMaxItems int
GeneratedContentDir string

UploadDir string
Expand Down Expand Up @@ -97,6 +99,7 @@ func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
AgentJobRetentionDays: 30, // Default: 30 days
LRUEvictionMaxRetries: 30, // Default: 30 retries
LRUEvictionRetryInterval: 1 * time.Second, // Default: 1 second
TracingMaxItems: 1024,
PathWithoutAuth: []string{
"/static/",
"/generated-audio/",
Expand Down Expand Up @@ -165,6 +168,10 @@ var EnableWatchDog = func(o *ApplicationConfig) {
o.WatchDog = true
}

var EnableTracing = func(o *ApplicationConfig) {
o.EnableTracing = true
}

var EnableWatchDogIdleCheck = func(o *ApplicationConfig) {
o.WatchDog = true
o.WatchDogIdle = true
Expand Down Expand Up @@ -418,6 +425,12 @@ func WithDebug(debug bool) AppOption {
}
}

func WithTracingMaxItems(items int) AppOption {
return func(o *ApplicationConfig) {
o.TracingMaxItems = items
}
}

func WithGeneratedContentDir(generatedContentDir string) AppOption {
return func(o *ApplicationConfig) {
o.GeneratedContentDir = generatedContentDir
Expand Down Expand Up @@ -543,6 +556,8 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
contextSize := o.ContextSize
f16 := o.F16
debug := o.Debug
tracingMaxItems := o.TracingMaxItems
enableTracing := o.EnableTracing
cors := o.CORS
csrf := o.CSRF
corsAllowOrigins := o.CORSAllowOrigins
Expand Down Expand Up @@ -599,6 +614,8 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
ContextSize: &contextSize,
F16: &f16,
Debug: &debug,
TracingMaxItems: &tracingMaxItems,
EnableTracing: &enableTracing,
CORS: &cors,
CSRF: &csrf,
CORSAllowOrigins: &corsAllowOrigins,
Expand Down Expand Up @@ -713,6 +730,12 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req
if settings.Debug != nil {
o.Debug = *settings.Debug
}
if settings.EnableTracing != nil {
o.EnableTracing = *settings.EnableTracing
}
if settings.TracingMaxItems != nil {
o.TracingMaxItems = *settings.TracingMaxItems
}
if settings.CORS != nil {
o.CORS = *settings.CORS
}
Expand Down
10 changes: 6 additions & 4 deletions core/config/runtime_settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ type RuntimeSettings struct {
LRUEvictionRetryInterval *string `json:"lru_eviction_retry_interval,omitempty"` // Interval between retries when waiting for busy models (e.g., 1s, 2s) (default: 1s)

// Performance settings
Threads *int `json:"threads,omitempty"`
ContextSize *int `json:"context_size,omitempty"`
F16 *bool `json:"f16,omitempty"`
Debug *bool `json:"debug,omitempty"`
Threads *int `json:"threads,omitempty"`
ContextSize *int `json:"context_size,omitempty"`
F16 *bool `json:"f16,omitempty"`
Debug *bool `json:"debug,omitempty"`
EnableTracing *bool `json:"enable_tracing,omitempty"`
TracingMaxItems *int `json:"tracing_max_items,omitempty"`

// Security/CORS settings
CORS *bool `json:"cors,omitempty"`
Expand Down
156 changes: 156 additions & 0 deletions core/http/middleware/trace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package middleware

import (
"bytes"
"github.com/emirpasic/gods/v2/queues/circularbuffer"
"io"
"net/http"
"sort"
"sync"
"time"

"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/xlog"
)

type APIExchangeRequest struct {
Method string `json:"method"`
Path string `json:"path"`
Headers *http.Header `json:"headers"`
Body *[]byte `json:"body"`
}

type APIExchangeResponse struct {
Status int `json:"status"`
Headers *http.Header `json:"headers"`
Body *[]byte `json:"body"`
}

type APIExchange struct {
Timestamp time.Time `json:"timestamp"`
Request APIExchangeRequest `json:"request"`
Response APIExchangeResponse `json:"response"`
}

var traceBuffer *circularbuffer.Queue[APIExchange]
var mu sync.Mutex
var logChan = make(chan APIExchange, 100)

type bodyWriter struct {
http.ResponseWriter
body *bytes.Buffer
}

func (w *bodyWriter) Write(b []byte) (int, error) {
w.body.Write(b)
return w.ResponseWriter.Write(b)
}

func (w *bodyWriter) Flush() {
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}

// TraceMiddleware intercepts and logs JSON API requests and responses
func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
if app.ApplicationConfig().EnableTracing && traceBuffer == nil {
traceBuffer = circularbuffer.New[APIExchange](app.ApplicationConfig().TracingMaxItems)

go func() {
for exchange := range logChan {
mu.Lock()
traceBuffer.Enqueue(exchange)
mu.Unlock()
}
}()
}

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if !app.ApplicationConfig().EnableTracing {
return next(c)
}

if c.Request().Header.Get("Content-Type") != "application/json" {
return next(c)
}

body, err := io.ReadAll(c.Request().Body)
if err != nil {
xlog.Error("Failed to read request body")
return err
}

// Restore the body for downstream handlers
c.Request().Body = io.NopCloser(bytes.NewBuffer(body))

startTime := time.Now()

// Wrap response writer to capture body
resBody := new(bytes.Buffer)
mw := &bodyWriter{
ResponseWriter: c.Response().Writer,
body: resBody,
}
c.Response().Writer = mw

err = next(c)
if err != nil {
c.Response().Writer = mw.ResponseWriter // Restore original writer if error
return err
}

// Create exchange log
requestHeaders := c.Request().Header.Clone()
requestBody := make([]byte, len(body))
copy(requestBody, body)
responseHeaders := c.Response().Header().Clone()
responseBody := make([]byte, resBody.Len())
copy(responseBody, resBody.Bytes())
exchange := APIExchange{
Timestamp: startTime,
Request: APIExchangeRequest{
Method: c.Request().Method,
Path: c.Path(),
Headers: &requestHeaders,
Body: &requestBody,
},
Response: APIExchangeResponse{
Status: c.Response().Status,
Headers: &responseHeaders,
Body: &responseBody,
},
}

select {
case logChan <- exchange:
default:
xlog.Warn("Trace channel full, dropping trace")
}

return nil
}
}
}

// GetTraces returns a copy of the logged API exchanges for display
func GetTraces() []APIExchange {
mu.Lock()
traces := traceBuffer.Values()
mu.Unlock()

sort.Slice(traces, func(i, j int) bool {
return traces[i].Timestamp.Before(traces[j].Timestamp)
})

return traces
}

// ClearTraces clears the in-memory logs
func ClearTraces() {
mu.Lock()
traceBuffer.Clear()
mu.Unlock()
}
14 changes: 12 additions & 2 deletions core/http/routes/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@ func RegisterOpenAIRoutes(app *echo.Echo,
re *middleware.RequestExtractor,
application *application.Application) {
// openAI compatible API endpoint
traceMiddleware := middleware.TraceMiddleware(application)

// realtime
// TODO: Modify/disable the API key middleware for this endpoint to allow ephemeral keys created by sessions
app.GET("/v1/realtime", openai.Realtime(application))
app.POST("/v1/realtime/sessions", openai.RealtimeTranscriptionSession(application))
app.POST("/v1/realtime/transcription_session", openai.RealtimeTranscriptionSession(application))
app.POST("/v1/realtime/sessions", openai.RealtimeTranscriptionSession(application), traceMiddleware)
app.POST("/v1/realtime/transcription_session", openai.RealtimeTranscriptionSession(application), traceMiddleware)

// chat
chatHandler := openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
chatMiddleware := []echo.MiddlewareFunc{
traceMiddleware,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
func(next echo.HandlerFunc) echo.HandlerFunc {
Expand All @@ -41,6 +43,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
// edit
editHandler := openai.EditEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
editMiddleware := []echo.MiddlewareFunc{
traceMiddleware,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EDIT)),
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
Expand All @@ -59,6 +62,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
// completion
completionHandler := openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
completionMiddleware := []echo.MiddlewareFunc{
traceMiddleware,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_COMPLETION)),
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
Expand All @@ -78,6 +82,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
// MCPcompletion
mcpCompletionHandler := openai.MCPCompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
mcpCompletionMiddleware := []echo.MiddlewareFunc{
traceMiddleware,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
func(next echo.HandlerFunc) echo.HandlerFunc {
Expand All @@ -95,6 +100,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
// embeddings
embeddingHandler := openai.EmbeddingsEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
embeddingMiddleware := []echo.MiddlewareFunc{
traceMiddleware,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
re.BuildConstantDefaultModelNameMiddleware("gpt-4o"),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
Expand All @@ -113,6 +119,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,

audioHandler := openai.TranscriptEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
audioMiddleware := []echo.MiddlewareFunc{
traceMiddleware,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
func(next echo.HandlerFunc) echo.HandlerFunc {
Expand All @@ -130,6 +137,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,

audioSpeechHandler := localai.TTSEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
audioSpeechMiddleware := []echo.MiddlewareFunc{
traceMiddleware,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }),
}
Expand All @@ -140,6 +148,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
// images
imageHandler := openai.ImageEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
imageMiddleware := []echo.MiddlewareFunc{
traceMiddleware,
// Default: use the first available image generation model
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_IMAGE)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
Expand All @@ -164,6 +173,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
// videos (OpenAI-compatible endpoints mapped to LocalAI video handler)
videoHandler := openai.VideoEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
videoMiddleware := []echo.MiddlewareFunc{
traceMiddleware,
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_VIDEO)),
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
func(next echo.HandlerFunc) echo.HandlerFunc {
Expand Down
Loading
Loading