/*
 * Lugh - Pure C LLM Inference Engine for Perl
 * 
 * Built on ggml tensor library
 * Thread-safe using registry pattern
 */

#define PERL_NO_GET_CONTEXT
#include "EXTERN.h"
#include "perl.h"
#include "XSUB.h"

#include "ppport.h"

#include <ggml.h>
#include <ggml-alloc.h>
#include <ggml-backend.h>
#include <ggml-cpu.h>
#include <gguf.h>

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* ============================================================================
 * Thread Safety Configuration
 * ============================================================================ */

/* Registry capacities. Slot 0 of each registry is never handed out by the
 * alloc_*_id() helpers, so the usable ID range is 1 .. MAX-1. */
#define MAX_CONTEXTS 4096
#define MAX_TENSORS  65536

#ifdef USE_ITHREADS
/* Threaded perl: one mutex guards the context registry (and is also taken by
 * the model registry helpers), a second one guards the tensor registry. */
static perl_mutex context_mutex;
static perl_mutex tensor_mutex;
/* Plain flag, read and written without any lock by INIT_MUTEXES(). */
static int mutex_initialized = 0;

#define CONTEXT_LOCK()   MUTEX_LOCK(&context_mutex)
#define CONTEXT_UNLOCK() MUTEX_UNLOCK(&context_mutex)
#define TENSOR_LOCK()    MUTEX_LOCK(&tensor_mutex)
#define TENSOR_UNLOCK()  MUTEX_UNLOCK(&tensor_mutex)

/* Initialise both mutexes the first time this is reached. The check on
 * mutex_initialized is itself unsynchronised.
 * NOTE(review): presumably safe only because the first call comes from BOOT
 * before other threads use the module - confirm. */
#define INIT_MUTEXES() do { \
    if (!mutex_initialized) { \
        MUTEX_INIT(&context_mutex); \
        MUTEX_INIT(&tensor_mutex); \
        mutex_initialized = 1; \
    } \
} while(0)

#else
/* Non-threaded perl: every locking macro expands to nothing. */
#define CONTEXT_LOCK()
#define CONTEXT_UNLOCK()
#define TENSOR_LOCK()
#define TENSOR_UNLOCK()
#define INIT_MUTEXES()
#endif

/* ============================================================================
 * State structures
 * ============================================================================ */

/* Registry entry backing one Lugh::Context object. */
typedef struct {
    struct ggml_context *ctx;  /* handle from ggml_init(); released with ggml_free() */
    size_t mem_size;           /* byte size that was passed to ggml_init() */
    int id;                    /* index of this entry in context_registry */
    int active;                /* set to 1 on creation; lookups treat 0 as "not found" */
} LughContext;

/* Registry entry for a tensor handle. Nothing in the visible part of this
 * file creates one; the field notes below follow the names and the lookup
 * helper only. */
typedef struct {
    struct ggml_tensor *tensor;  /* wrapped ggml tensor */
    int context_id;  /* ID of owning context */
    int id;          /* presumably the index in tensor_registry - TODO confirm */
    int active;      /* get_tensor_by_id() reports entries with 0 here as missing */
} LughTensor;

/* Registry entry backing one Lugh::Model object. The owned members (gguf,
 * ctx, filename, architecture) are released by lugh_model_free(). */
typedef struct {
    struct gguf_context *gguf;    /* GGUF handle; released with gguf_free() */
    struct ggml_context *ctx;     /* Context for tensor data */
    char *filename;               /* heap string released with Safefree() */
    int id;                       /* presumably the index in model_registry - TODO confirm */
    int active;                   /* get_model_by_id() reports entries with 0 here as missing */
    /* Model metadata */
    int64_t n_tensors;            /* NOTE(review): populated outside the visible code */
    int64_t n_kv;                 /* NOTE(review): populated outside the visible code */
    char *architecture;           /* heap string released with Safefree() */
} LughModel;

/* ============================================================================
 * Global Registries (thread-safe via integer IDs)
 * ============================================================================ */

/* ID-indexed tables of live objects. Perl-side handles carry only the
 * integer index; index 0 is never allocated. model_registry shares
 * MAX_CONTEXTS as its capacity and is guarded by the context mutex. */
static LughContext* context_registry[MAX_CONTEXTS] = {NULL};
static LughTensor*  tensor_registry[MAX_TENSORS]   = {NULL};
static LughModel*   model_registry[MAX_CONTEXTS]   = {NULL};
/* Where each alloc_*_id() helper starts its next search. */
static int next_context_id = 1;
static int next_tensor_id  = 1;
static int next_model_id   = 1;

/*
 * Find an unused context slot.
 *
 * Scans at most MAX_CONTEXTS candidates starting at next_context_id,
 * wrapping around and mapping index 0 to 1. Returns the free index, or -1
 * when no empty slot was found. The slot itself is not written here; the
 * caller stores its entry into context_registry afterwards.
 */
static int alloc_context_id(void) {
    int found = -1;
    int probe;

    CONTEXT_LOCK();
    for (probe = 0; probe < MAX_CONTEXTS && found < 0; probe++) {
        int slot = (next_context_id + probe) % MAX_CONTEXTS;
        if (slot == 0) {
            slot = 1;  /* index 0 is reserved */
        }
        if (context_registry[slot] != NULL) {
            continue;
        }
        found = slot;
    }
    if (found > 0) {
        /* Resume the next search just past the slot we handed out. */
        next_context_id = (found + 1) % MAX_CONTEXTS;
        if (next_context_id == 0) {
            next_context_id = 1;
        }
    }
    CONTEXT_UNLOCK();

    return found;
}

/*
 * Resolve a context ID to its registry entry.
 *
 * Returns NULL for IDs outside 1 .. MAX_CONTEXTS-1, for empty slots, and
 * for entries whose active flag is 0. The lock is held only for the lookup.
 */
static LughContext* get_context_by_id(int id) {
    LughContext *entry;

    if (id < 1 || id > MAX_CONTEXTS - 1) {
        return NULL;
    }

    CONTEXT_LOCK();
    entry = context_registry[id];
    if (entry != NULL && entry->active == 0) {
        entry = NULL;
    }
    CONTEXT_UNLOCK();

    return entry;
}

/*
 * Find an unused tensor slot.
 *
 * Scans at most MAX_TENSORS candidates starting at next_tensor_id, wrapping
 * around and mapping index 0 to 1. Returns the free index or -1 if none was
 * found. The slot is not filled in here.
 */
static int alloc_tensor_id(void) {
    int chosen = -1;
    int offset = 0;

    TENSOR_LOCK();
    while (offset < MAX_TENSORS) {
        int candidate = (next_tensor_id + offset) % MAX_TENSORS;
        if (candidate == 0) {
            candidate = 1;  /* index 0 is reserved */
        }
        if (tensor_registry[candidate] == NULL) {
            chosen = candidate;
            /* Advance the cursor, wrapping past the reserved index 0. */
            next_tensor_id = (chosen + 1 == MAX_TENSORS) ? 1 : chosen + 1;
            break;
        }
        offset++;
    }
    TENSOR_UNLOCK();

    return chosen;
}

/*
 * Resolve a tensor ID to its registry entry.
 *
 * Returns NULL for IDs outside 1 .. MAX_TENSORS-1, for empty slots, and for
 * entries whose active flag is 0.
 */
static LughTensor* get_tensor_by_id(int id) {
    LughTensor *entry;

    if (id < 1 || id > MAX_TENSORS - 1) {
        return NULL;
    }

    TENSOR_LOCK();
    entry = tensor_registry[id];
    if (entry != NULL && entry->active == 0) {
        entry = NULL;
    }
    TENSOR_UNLOCK();

    return entry;
}

/*
 * Find an unused model slot.
 *
 * Uses the context mutex and MAX_CONTEXTS as the capacity of
 * model_registry. Scans from next_model_id with wrap-around, never yields
 * index 0, and returns -1 when no empty slot exists. The slot is not filled
 * in here.
 */
static int alloc_model_id(void) {
    int result = -1;
    int step;

    CONTEXT_LOCK();
    for (step = 0; step < MAX_CONTEXTS; step++) {
        int idx = (next_model_id + step) % MAX_CONTEXTS;
        if (idx == 0) {
            idx = 1;  /* index 0 is reserved */
        }
        if (model_registry[idx] != NULL) {
            continue;
        }
        result = idx;
        next_model_id = (idx + 1) % MAX_CONTEXTS;
        if (next_model_id == 0) {
            next_model_id = 1;
        }
        break;
    }
    CONTEXT_UNLOCK();

    return result;
}

/*
 * Resolve a model ID to its registry entry.
 *
 * Returns NULL for IDs outside 1 .. MAX_CONTEXTS-1, for empty slots, and
 * for entries whose active flag is 0. Guarded by the context mutex.
 */
static LughModel* get_model_by_id(int id) {
    LughModel *entry;

    if (id < 1 || id > MAX_CONTEXTS - 1) {
        return NULL;
    }

    CONTEXT_LOCK();
    entry = model_registry[id];
    if (entry != NULL && entry->active == 0) {
        entry = NULL;
    }
    CONTEXT_UNLOCK();

    return entry;
}

/* ============================================================================
 * Magic vtable for cleanup
 * ============================================================================ */

/*
 * Magic "free" hook for Lugh::Context handles.
 *
 * mg->mg_ptr carries the registry ID (not a pointer). The entry is looked
 * up and removed from context_registry within a single lock hold, so two
 * callers racing on the same ID cannot both obtain the entry and free it
 * twice. The ggml context and the entry are released once the entry is no
 * longer reachable through the registry.
 *
 * Always returns 0. Out-of-range IDs, empty slots and inactive entries are
 * ignored.
 */
static int lugh_context_free(pTHX_ SV *sv, MAGIC *mg) {
    int id = (int)(IV)mg->mg_ptr;
    LughContext *lctx = NULL;

    (void)sv;

    if (id <= 0 || id >= MAX_CONTEXTS) {
        return 0;
    }

    /* Claim the entry: after this section only this caller holds it. */
    CONTEXT_LOCK();
    lctx = context_registry[id];
    if (lctx && lctx->active) {
        lctx->active = 0;
        context_registry[id] = NULL;
    } else {
        lctx = NULL;
    }
    CONTEXT_UNLOCK();

    if (lctx) {
        if (lctx->ctx) {
            ggml_free(lctx->ctx);
            lctx->ctx = NULL;
        }
        Safefree(lctx);
    }
    return 0;
}

static MGVTBL lugh_context_vtbl = {
    NULL,                /* get */
    NULL,                /* set */
    NULL,                /* len */
    NULL,                /* clear */
    lugh_context_free,   /* free */
    NULL,                /* copy */
    NULL,                /* dup */
    NULL                 /* local */
};

/*
 * Magic "free" hook for Lugh::Model handles.
 *
 * mg->mg_ptr carries the registry ID (not a pointer). The entry is looked
 * up and removed from model_registry within a single hold of the context
 * mutex, so two callers racing on the same ID cannot both obtain the entry
 * and free it twice. The owned resources (tensor context, GGUF handle,
 * filename and architecture strings) and the entry itself are released
 * once the entry is no longer reachable through the registry.
 *
 * Always returns 0. Out-of-range IDs, empty slots and inactive entries are
 * ignored.
 */
static int lugh_model_free(pTHX_ SV *sv, MAGIC *mg) {
    int id = (int)(IV)mg->mg_ptr;
    LughModel *lm = NULL;

    (void)sv;

    if (id <= 0 || id >= MAX_CONTEXTS) {
        return 0;
    }

    /* Claim the entry: after this section only this caller holds it. */
    CONTEXT_LOCK();
    lm = model_registry[id];
    if (lm && lm->active) {
        lm->active = 0;
        model_registry[id] = NULL;
    } else {
        lm = NULL;
    }
    CONTEXT_UNLOCK();

    if (lm) {
        if (lm->ctx) {
            ggml_free(lm->ctx);
            lm->ctx = NULL;
        }
        if (lm->gguf) {
            gguf_free(lm->gguf);
            lm->gguf = NULL;
        }
        if (lm->filename) {
            Safefree(lm->filename);
            lm->filename = NULL;
        }
        if (lm->architecture) {
            Safefree(lm->architecture);
            lm->architecture = NULL;
        }
        Safefree(lm);
    }
    return 0;
}

static MGVTBL lugh_model_vtbl = {
    NULL,                /* get */
    NULL,                /* set */
    NULL,                /* len */
    NULL,                /* clear */
    lugh_model_free,     /* free */
    NULL,                /* copy */
    NULL,                /* dup */
    NULL                 /* local */
};

/*
 * Map a Perl-side Lugh::Model handle to its registry entry.
 *
 * Croaks when sv is not a blessed reference, when the referenced SV has no
 * ext magic using lugh_model_vtbl, or when the ID stored in the magic no
 * longer resolves to an active model. Never returns NULL.
 */
static LughModel* get_lugh_model(pTHX_ SV *sv) {
    MAGIC *ext_magic;
    LughModel *found;

    if (!sv_isobject(sv)) {
        croak("Not a Lugh::Model object");
    }

    ext_magic = mg_find(SvRV(sv), PERL_MAGIC_ext);
    if (ext_magic == NULL || ext_magic->mg_virtual != &lugh_model_vtbl) {
        croak("Invalid Lugh::Model object");
    }

    /* mg_ptr holds the registry ID, not an address. */
    found = get_model_by_id((int)(IV)ext_magic->mg_ptr);
    if (found == NULL) {
        croak("Lugh::Model has been destroyed");
    }

    return found;
}

/*
 * Map a Perl-side Lugh::Context handle to its registry entry.
 *
 * Croaks when sv is not a blessed reference, when the referenced SV has no
 * ext magic using lugh_context_vtbl, or when the ID stored in the magic no
 * longer resolves to an active context. Never returns NULL.
 */
static LughContext* get_lugh_context(pTHX_ SV *sv) {
    MAGIC *ext_magic;
    LughContext *found;

    if (!sv_isobject(sv)) {
        croak("Not a Lugh::Context object");
    }

    ext_magic = mg_find(SvRV(sv), PERL_MAGIC_ext);
    if (ext_magic == NULL || ext_magic->mg_virtual != &lugh_context_vtbl) {
        croak("Invalid Lugh::Context object");
    }

    /* mg_ptr holds the registry ID, not an address. */
    found = get_context_by_id((int)(IV)ext_magic->mg_ptr);
    if (found == NULL) {
        croak("Lugh::Context has been destroyed");
    }

    return found;
}

/* ============================================================================
 * XS Functions
 * ============================================================================ */

MODULE = Lugh    PACKAGE = Lugh

PROTOTYPES: DISABLE

BOOT:
    /* Set up the registry mutexes; expands to nothing on non-threaded perls. */
    INIT_MUTEXES();

const char *
version()
CODE:
    /* Version string of the XS layer; hard-coded in this file. */
    RETVAL = "0.01";
OUTPUT:
    RETVAL

const char *
ggml_version()
CODE:
    /* Return ggml build info. This is a fixed string literal; it is not
     * queried from the linked ggml library. */
    RETVAL = "ggml 0.9.5";
OUTPUT:
    RETVAL

MODULE = Lugh    PACKAGE = Lugh::Context

SV *
new(class, ...)
    char *class
PREINIT:
    LughContext *state;
    struct ggml_init_params init;
    size_t requested = 16 * 1024 * 1024;  /* default arena: 16MB */
    SV *holder;
    int slot;
    int arg;
CODE:
    /* Constructor: Lugh::Context->new(mem_size => BYTES).
     * Creates a ggml context of the requested size, registers it under a
     * fresh integer ID and returns a blessed reference whose ext magic
     * carries that ID. Croaks if no ID is free or ggml_init() fails. */
    INIT_MUTEXES();

    /* Pick the registry slot before doing anything else. */
    slot = alloc_context_id();
    if (slot < 0) {
        croak("Maximum number of contexts (%d) reached", MAX_CONTEXTS);
    }

    /* Walk complete key => value pairs; a trailing key without a value and
     * any key other than "mem_size" are ignored. */
    for (arg = 1; arg + 1 < items; arg += 2) {
        const char *key = SvPV_nolen(ST(arg));
        if (strEQ(key, "mem_size")) {
            requested = SvUV(ST(arg + 1));
        }
    }

    /* ggml owns the buffer (mem_buffer NULL) and allocates tensor data. */
    init.mem_size   = requested;
    init.mem_buffer = NULL;
    init.no_alloc   = false;

    Newxz(state, 1, LughContext);
    state->id       = slot;
    state->active   = 1;
    state->mem_size = requested;
    state->ctx      = ggml_init(init);
    if (state->ctx == NULL) {
        Safefree(state);
        croak("Failed to initialize ggml context");
    }

    /* Publish the entry so lookups by ID can see it. */
    CONTEXT_LOCK();
    context_registry[slot] = state;
    CONTEXT_UNLOCK();

    /* The handle stores the ID in mg_ptr (length 0, so nothing is copied). */
    holder = newSV(0);
    sv_magicext(holder, NULL, PERL_MAGIC_ext, &lugh_context_vtbl, INT2PTR(char*, (IV)slot), 0);
    RETVAL = sv_bless(newRV_noinc(holder), gv_stashpv(class, GV_ADD));
OUTPUT:
    RETVAL

int
id(self)
    SV *self
PREINIT:
    LughContext *owner;
CODE:
    /* Registry ID of this context; croaks if self is not a live context. */
    owner = get_lugh_context(aTHX_ self);
    RETVAL = owner->id;
OUTPUT:
    RETVAL

size_t
mem_size(self)
    SV *self
PREINIT:
    LughContext *owner;
CODE:
    /* Byte size recorded when the context was created; croaks if self is
     * not a live context. */
    owner = get_lugh_context(aTHX_ self);
    RETVAL = owner->mem_size;
OUTPUT:
    RETVAL

size_t
used_mem(self)
    SV *self
PREINIT:
    LughContext *owner;
CODE:
    /* Value reported by ggml_used_mem() for the wrapped ggml context;
     * croaks if self is not a live context. */
    owner = get_lugh_context(aTHX_ self);
    RETVAL = ggml_used_mem(owner->ctx);
OUTPUT:
    RETVAL

void
DESTROY(self)
    SV *self
CODE:
    /* Intentionally does nothing: the ggml context is released by the free
     * hook of lugh_context_vtbl attached to the referenced SV. */
    PERL_UNUSED_VAR(self);

MODULE = Lugh  PACKAGE = Lugh::Inference

=pod

=head1 Lugh::Inference

The inference engine - runs the forward pass through the model

=cut

SV *
new(class, ...)
    const char *class
PREINIT:
    LughModel *lm = NULL;
    SV *model_ref = NULL;
    HV *self_hv;
    int ctx_len = 2048;
    int thread_count = 4;
    int arg;
    size_t h;
    int64_t key_id;
CODE:
    /* Constructor: Lugh::Inference->new(model => $model, n_ctx => N, n_threads => N).
     * Returns a blessed hash holding the model handle, the two settings and
     * the "llama.*" hyperparameters found in the model's GGUF metadata.
     * Croaks on an odd number of named arguments or a missing model. */
    if ((items - 1) % 2 != 0) {
        croak("Usage: Lugh::Inference->new(model => $model, n_ctx => 2048, n_threads => 4)");
    }

    /* Unrecognised keys are ignored. */
    for (arg = 1; arg < items; arg += 2) {
        const char *opt = SvPV_nolen(ST(arg));
        SV *value = ST(arg + 1);

        if (strEQ(opt, "n_threads")) {
            thread_count = SvIV(value);
        } else if (strEQ(opt, "n_ctx")) {
            ctx_len = SvIV(value);
        } else if (strEQ(opt, "model")) {
            model_ref = value;
            lm = get_lugh_model(aTHX_ value);  /* croaks on a bad handle */
        }
    }

    if (lm == NULL) {
        croak("model parameter is required");
    }

    /* The hash keeps its own reference to the model handle. */
    self_hv = newHV();
    hv_store(self_hv, "_model", 6, SvREFCNT_inc(model_ref), 0);
    hv_store(self_hv, "n_ctx", 5, newSViv(ctx_len), 0);
    hv_store(self_hv, "n_threads", 9, newSViv(thread_count), 0);

    {
        /* GGUF metadata key -> hash field; copied only when the key exists. */
        static const struct {
            const char *gguf_key;
            const char *field;
            I32 field_len;
        } hparams[] = {
            { "llama.embedding_length",        "n_embd",    6 },
            { "llama.block_count",             "n_layer",   7 },
            { "llama.attention.head_count",    "n_head",    6 },
            { "llama.attention.head_count_kv", "n_head_kv", 9 },
            { "llama.feed_forward_length",     "n_ff",      4 }
        };

        for (h = 0; h < sizeof(hparams) / sizeof(hparams[0]); h++) {
            key_id = gguf_find_key(lm->gguf, hparams[h].gguf_key);
            if (key_id >= 0) {
                hv_store(self_hv, hparams[h].field, hparams[h].field_len,
                         newSViv(gguf_get_val_u32(lm->gguf, key_id)), 0);
            }
        }
    }

    /* n_vocab is always present: 32000 stands in when the key is missing. */
    key_id = gguf_find_key(lm->gguf, "llama.vocab_size");
    if (key_id >= 0) {
        hv_store(self_hv, "n_vocab", 7, newSViv(gguf_get_val_u32(lm->gguf, key_id)), 0);
    } else {
        hv_store(self_hv, "n_vocab", 7, newSViv(32000), 0);
    }

    RETVAL = sv_bless(newRV_noinc((SV*)self_hv), gv_stashpv(class, GV_ADD));
OUTPUT:
    RETVAL

SV *
model(self)
    SV *self
PREINIT:
    SV **slot;
CODE:
    /* Returns the model handle stored under "_model" (with its reference
     * count bumped for the caller), or undef when the slot is missing.
     * Croaks unless self is a reference to a hash. */
    if (!(SvROK(self) && SvTYPE(SvRV(self)) == SVt_PVHV)) {
        croak("Not a valid Lugh::Inference object");
    }
    slot = hv_fetch((HV*)SvRV(self), "_model", 6, 0);
    if (slot == NULL || *slot == NULL) {
        RETVAL = &PL_sv_undef;
    } else {
        RETVAL = SvREFCNT_inc(*slot);
    }
OUTPUT:
    RETVAL

int
n_ctx(self)
    SV *self
PREINIT:
    HV *hv;
    SV **svp;
CODE:
    /* Value stored under "n_ctx" in the object hash, or 2048 when the slot
     * is absent. Croaks unless self is a reference to a hash, so that a
     * non-reference argument is rejected instead of being dereferenced. */
    if (!SvROK(self) || SvTYPE(SvRV(self)) != SVt_PVHV) {
        croak("Not a valid Lugh::Inference object");
    }
    hv = (HV*)SvRV(self);
    svp = hv_fetch(hv, "n_ctx", 5, 0);
    RETVAL = svp ? (int)SvIV(*svp) : 2048;
OUTPUT:
    RETVAL

int
n_vocab(self)
    SV *self
PREINIT:
    HV *hv;
    SV **svp;
CODE:
    /* Value stored under "n_vocab" in the object hash, or 32000 when the
     * slot is absent. Croaks unless self is a reference to a hash, so that
     * a non-reference argument is rejected instead of being dereferenced. */
    if (!SvROK(self) || SvTYPE(SvRV(self)) != SVt_PVHV) {
        croak("Not a valid Lugh::Inference object");
    }
    hv = (HV*)SvRV(self);
    svp = hv_fetch(hv, "n_vocab", 7, 0);
    RETVAL = svp ? (int)SvIV(*svp) : 32000;
OUTPUT:
    RETVAL

int
n_embd(self)
    SV *self
PREINIT:
    HV *hv;
    SV **svp;
CODE:
    /* Value stored under "n_embd" in the object hash, or 2048 when the slot
     * is absent. Croaks unless self is a reference to a hash, so that a
     * non-reference argument is rejected instead of being dereferenced. */
    if (!SvROK(self) || SvTYPE(SvRV(self)) != SVt_PVHV) {
        croak("Not a valid Lugh::Inference object");
    }
    hv = (HV*)SvRV(self);
    svp = hv_fetch(hv, "n_embd", 6, 0);
    RETVAL = svp ? (int)SvIV(*svp) : 2048;
OUTPUT:
    RETVAL

int
n_layer(self)
    SV *self
PREINIT:
    HV *hv;
    SV **svp;
CODE:
    /* Value stored under "n_layer" in the object hash, or 22 when the slot
     * is absent. Croaks unless self is a reference to a hash, so that a
     * non-reference argument is rejected instead of being dereferenced. */
    if (!SvROK(self) || SvTYPE(SvRV(self)) != SVt_PVHV) {
        croak("Not a valid Lugh::Inference object");
    }
    hv = (HV*)SvRV(self);
    svp = hv_fetch(hv, "n_layer", 7, 0);
    RETVAL = svp ? (int)SvIV(*svp) : 22;
OUTPUT:
    RETVAL

int
n_head(self)
    SV *self
PREINIT:
    HV *hv;
    SV **svp;
CODE:
    /* Value stored under "n_head" in the object hash, or 32 when the slot
     * is absent. Croaks unless self is a reference to a hash, so that a
     * non-reference argument is rejected instead of being dereferenced. */
    if (!SvROK(self) || SvTYPE(SvRV(self)) != SVt_PVHV) {
        croak("Not a valid Lugh::Inference object");
    }
    hv = (HV*)SvRV(self);
    svp = hv_fetch(hv, "n_head", 6, 0);
    RETVAL = svp ? (int)SvIV(*svp) : 32;
OUTPUT:
    RETVAL

void
forward(self, ...)
    SV *self
PREINIT:
    HV *hv;
    SV **svp;
    LughModel *model;
    int i, n_tokens = 0;
    int *tokens = NULL;
    int n_embd, n_layer, n_head, n_head_kv, n_vocab, n_threads;
    int n_ctx_orig, n_rot;
    int use_flash_attn = 0;  /* 0 = standard attention, 1 = flash attention */
    float rms_norm_eps = 1e-5f;
    float rope_freq_base = 10000.0f;
    float rope_freq_scale = 1.0f;
    int64_t key_id;
    struct ggml_context *ctx_w = NULL;  /* weights context (model->ctx) */
    struct ggml_context *ctx_c = NULL;  /* compute context */
    struct ggml_cgraph *gf = NULL;
    struct ggml_tensor *cur = NULL;
    struct ggml_tensor *inpL = NULL;
    ggml_backend_t backend = NULL;
    ggml_gallocr_t allocr = NULL;
PPCODE:
    /* Get model */
    hv = (HV*)SvRV(self);
    svp = hv_fetch(hv, "_model", 6, 0);
    if (!svp || !*svp) croak("No model in inference object");
    model = get_lugh_model(aTHX_ *svp);
    if (!model) croak("Invalid model");
    
    /* Parse tokens from args */
    /* Expect either an array ref or a list of tokens */
    if (items == 2 && SvROK(ST(1)) && SvTYPE(SvRV(ST(1))) == SVt_PVAV) {
        /* Array reference passed */
        AV *av = (AV*)SvRV(ST(1));
        n_tokens = av_len(av) + 1;
        if (n_tokens == 0) {
            croak("forward() requires at least one token");
        }
        Newx(tokens, n_tokens, int);
        for (i = 0; i < n_tokens; i++) {
            SV **elem = av_fetch(av, i, 0);
            tokens[i] = elem ? SvIV(*elem) : 0;
        }
    } else {
        /* List of tokens passed directly */
        for (i = 1; i < items; i++) {
            n_tokens++;
        }
        if (n_tokens == 0) {
            croak("forward() requires at least one token");
        }
        Newx(tokens, n_tokens, int);
        for (i = 0; i < n_tokens; i++) {
            tokens[i] = SvIV(ST(i + 1));
        }
    }
    

    
    /* Get hyperparameters */
    svp = hv_fetch(hv, "n_embd", 6, 0);
    n_embd = svp ? SvIV(*svp) : 2048;
    svp = hv_fetch(hv, "n_layer", 7, 0);
    n_layer = svp ? SvIV(*svp) : 22;
    svp = hv_fetch(hv, "n_head", 6, 0);
    n_head = svp ? SvIV(*svp) : 32;
    svp = hv_fetch(hv, "n_head_kv", 9, 0);
    n_head_kv = svp ? SvIV(*svp) : 4;
    svp = hv_fetch(hv, "n_vocab", 7, 0);
    n_vocab = svp ? SvIV(*svp) : 32000;
    svp = hv_fetch(hv, "n_threads", 9, 0);
    n_threads = svp ? SvIV(*svp) : 4;
    svp = hv_fetch(hv, "flash_attn", 10, 0);
    use_flash_attn = svp ? SvIV(*svp) : 0;
    
    /* Get normalization parameter */
    key_id = gguf_find_key(model->gguf, "llama.attention.layer_norm_rms_epsilon");
    if (key_id >= 0) rms_norm_eps = gguf_get_val_f32(model->gguf, key_id);
    
    /* Get RoPE parameters from model */
    key_id = gguf_find_key(model->gguf, "llama.rope.freq_base");
    if (key_id >= 0) rope_freq_base = gguf_get_val_f32(model->gguf, key_id);
    
    key_id = gguf_find_key(model->gguf, "llama.rope.dimension_count");
    n_rot = (key_id >= 0) ? gguf_get_val_u32(model->gguf, key_id) : (n_embd / n_head);
    
    key_id = gguf_find_key(model->gguf, "llama.context_length");
    n_ctx_orig = (key_id >= 0) ? gguf_get_val_u32(model->gguf, key_id) : 2048;
    
    /* Use model's context directly for weights - data is already loaded */
    ctx_w = model->ctx;
    
    /* Initialize CPU backend */
    backend = ggml_backend_cpu_init();
    if (!backend) {
        Safefree(tokens);
        croak("Failed to initialize CPU backend");
    }
    ggml_backend_cpu_set_n_threads(backend, n_threads);
    
    /* Create compute context for intermediate tensors */
    {
        size_t ctx_size = 512 * 1024 * 1024; /* 512MB for computation */
        struct ggml_init_params params = {
            .mem_size   = ctx_size,
            .mem_buffer = NULL,
            .no_alloc   = true,
        };
        ctx_c = ggml_init(params);
        if (!ctx_c) {
            ggml_backend_free(backend);
            Safefree(tokens);
            croak("Failed to create compute context");
        }
    };
    
    /* Build the forward pass graph */
    {
        struct ggml_tensor *tok_embd = ggml_get_tensor(ctx_w, "token_embd.weight");
        struct ggml_tensor *output_norm = ggml_get_tensor(ctx_w, "output_norm.weight");
        struct ggml_tensor *output = ggml_get_tensor(ctx_w, "output.weight");
        struct ggml_tensor *pos;  /* Position tensor - created once, reused for all layers */
        int head_dim = n_embd / n_head;
        int layer;
        
        /* Support tied embeddings: if output.weight is missing, use token_embd.weight */
        if (!output) {
            output = tok_embd;
        }
        
        if (!tok_embd) {
            ggml_free(ctx_c);
            ggml_backend_free(backend);
            Safefree(tokens);
            croak("Required tensors not found in model");
        }
        
        /* Create position tensor for RoPE - one tensor for all layers */
        pos = ggml_new_tensor_1d(ctx_c, GGML_TYPE_I32, n_tokens);
        ggml_set_name(pos, "pos");
        
        /* Create causal attention mask - [n_tokens, n_tokens] */
        /* Lower triangular with -inf in upper triangle */
        struct ggml_tensor *kq_mask = ggml_new_tensor_2d(ctx_c, GGML_TYPE_F32, n_tokens, n_tokens);
        ggml_set_name(kq_mask, "kq_mask");
        
        /* Create input embedding: lookup tokens in embedding table */
        {
            struct ggml_tensor *inp_tokens = ggml_new_tensor_1d(ctx_c, GGML_TYPE_I32, n_tokens);
            ggml_set_name(inp_tokens, "inp_tokens");
            

            
            /* Will set data after allocation */
            inpL = ggml_get_rows(ctx_c, tok_embd, inp_tokens);
            ggml_set_name(inpL, "inp_embd");
            

        }
        
        cur = inpL;
        
        /* Process each transformer layer */
        for (layer = 0; layer < n_layer; layer++) {
            char name[64];
            struct ggml_tensor *attn_norm, *ffn_norm;
            struct ggml_tensor *wq, *wk, *wv, *wo;
            struct ggml_tensor *ffn_gate, *ffn_up, *ffn_down;
            struct ggml_tensor *residual;
            
            /* Get layer weights */
            snprintf(name, sizeof(name), "blk.%d.attn_norm.weight", layer);
            attn_norm = ggml_get_tensor(ctx_w, name);
            
            snprintf(name, sizeof(name), "blk.%d.attn_q.weight", layer);
            wq = ggml_get_tensor(ctx_w, name);
            snprintf(name, sizeof(name), "blk.%d.attn_k.weight", layer);
            wk = ggml_get_tensor(ctx_w, name);
            snprintf(name, sizeof(name), "blk.%d.attn_v.weight", layer);
            wv = ggml_get_tensor(ctx_w, name);
            snprintf(name, sizeof(name), "blk.%d.attn_output.weight", layer);
            wo = ggml_get_tensor(ctx_w, name);
            
            snprintf(name, sizeof(name), "blk.%d.ffn_norm.weight", layer);
            ffn_norm = ggml_get_tensor(ctx_w, name);
            snprintf(name, sizeof(name), "blk.%d.ffn_gate.weight", layer);
            ffn_gate = ggml_get_tensor(ctx_w, name);
            snprintf(name, sizeof(name), "blk.%d.ffn_up.weight", layer);
            ffn_up = ggml_get_tensor(ctx_w, name);
            snprintf(name, sizeof(name), "blk.%d.ffn_down.weight", layer);
            ffn_down = ggml_get_tensor(ctx_w, name);
            
            if (!attn_norm || !wq || !wk || !wv || !wo || !ffn_norm || !ffn_gate || !ffn_up || !ffn_down) {
                continue;  /* Skip layers with missing weights */
            }
            
            residual = cur;
            
            /* RMS Norm before attention */
            cur = ggml_rms_norm(ctx_c, cur, rms_norm_eps);
            cur = ggml_mul(ctx_c, cur, attn_norm);
            
            /* Self-attention: Q, K, V projections */
            {
                struct ggml_tensor *q, *k, *v;
                struct ggml_tensor *attn_out;
                int n_kv_dim = n_head_kv * head_dim;  /* 256 for TinyLlama */
                int n_rep = n_head / n_head_kv;  /* how many times to repeat KV heads */
                

                
                /* Q, K, V projections */
                /* cur: [n_embd, n_tokens], wq: [n_embd, n_embd] */
                /* mul_mat(wq, cur) requires wq->ne[0] == cur->ne[0] */
                /* Result: [wq->ne[1], cur->ne[1]] = [n_embd, n_tokens] */
                q = ggml_mul_mat(ctx_c, wq, cur);  /* [n_embd, n_tokens] */
                k = ggml_mul_mat(ctx_c, wk, cur);  /* [n_kv_dim, n_tokens] */
                v = ggml_mul_mat(ctx_c, wv, cur);  /* [n_kv_dim, n_tokens] */
                

                
                /* Reshape for attention heads */
                /* Q: [n_embd, n_tokens] -> [head_dim, n_head, n_tokens] */
                q = ggml_reshape_3d(ctx_c, q, head_dim, n_head, n_tokens);
                /* K, V: [n_kv_dim, n_tokens] -> [head_dim, n_head_kv, n_tokens] */
                k = ggml_reshape_3d(ctx_c, k, head_dim, n_head_kv, n_tokens);
                v = ggml_reshape_3d(ctx_c, v, head_dim, n_head_kv, n_tokens);
                
                /* Apply RoPE (rotary positional embeddings) - use shared pos tensor */
                /* ggml_rope_ext parameters:
                 *   a = input tensor, b = position tensor, c = freq factors (NULL for standard)
                 *   n_dims = dimensions to rotate (n_rot from model)
                 *   mode = 0 for standard RoPE
                 *   n_ctx_orig = original context length
                 *   freq_base = base frequency (10000 for llama)
                 *   freq_scale = 1.0 for no scaling
                 *   ext_factor, attn_factor, beta_fast, beta_slow = YaRN params (0 for standard)
                 */
                q = ggml_rope_ext(ctx_c, q, pos, NULL, n_rot, 0, n_ctx_orig, 
                                  rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f);
                k = ggml_rope_ext(ctx_c, k, pos, NULL, n_rot, 0, n_ctx_orig,
                                  rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f);
                
                /* For GQA: skip repeat for now to debug */
                /* TODO: fix GQA repeat */
                (void)n_rep;  /* suppress unused warning */
                
                if (use_flash_attn) {
                    /* Flash Attention path */
                    /* ggml_flash_attn_ext expects:
                     *   q: [n_embd_k, n_batch, n_head, ne3]
                     *   k: [n_embd_k, n_kv, n_head_kv, ne3]
                     *   v: [n_embd_v, n_kv, n_head_kv, ne3] - NOT transposed
                     *   mask: [n_kv, n_batch, ne32, ne33] or NULL for causal
                     *   res: [n_embd_v, n_head, n_batch, ne3] - permuted output
                     *
                     * Our tensors after rope:
                     *   q: [head_dim, n_head, n_tokens]
                     *   k: [head_dim, n_head_kv, n_tokens]
                     *   v: [head_dim, n_head_kv, n_tokens]
                     */
                    float scale = 1.0f / sqrtf((float)head_dim);
                    
                    /* Reshape for flash attention: add batch dimension */
                    /* q: [head_dim, n_head, n_tokens] -> [head_dim, n_tokens, n_head, 1] */
                    struct ggml_tensor *q_fa = ggml_cont(ctx_c, ggml_permute(ctx_c, q, 0, 2, 1, 3));
                    /* k: [head_dim, n_head_kv, n_tokens] -> [head_dim, n_tokens, n_head_kv, 1] */
                    struct ggml_tensor *k_fa = ggml_cont(ctx_c, ggml_permute(ctx_c, k, 0, 2, 1, 3));
                    /* v: [head_dim, n_head_kv, n_tokens] -> [head_dim, n_tokens, n_head_kv, 1] */
                    struct ggml_tensor *v_fa = ggml_cont(ctx_c, ggml_permute(ctx_c, v, 0, 2, 1, 3));
                    
                    /* Flash attention - NULL mask means causal masking */
                    /* max_bias = 0.0 (no ALiBi), logit_softcap = 0.0 (disabled) */
                    attn_out = ggml_flash_attn_ext(ctx_c, q_fa, k_fa, v_fa, NULL, scale, 0.0f, 0.0f);
                    
                    /* Result is [head_dim, n_head, n_tokens, 1] */
                    /* Reshape to [head_dim, n_head, n_tokens] */
                    attn_out = ggml_reshape_3d(ctx_c, attn_out, head_dim, n_head, n_tokens);
                } else {
                    /* Standard scaled dot-product attention */
                    /* Permute for attention: [head_dim, n_head, n_tokens] -> [head_dim, n_tokens, n_head] */
                    q = ggml_permute(ctx_c, q, 0, 2, 1, 3);
                    k = ggml_permute(ctx_c, k, 0, 2, 1, 3);
                    v = ggml_permute(ctx_c, v, 0, 2, 1, 3);
                    
                    /* Q,K,V after permute: [head_dim, n_tokens, n_head] (or n_head_kv for K,V) */
                    {
                        float scale = 1.0f / sqrtf((float)head_dim);
                        struct ggml_tensor *kq;
                        
                        /* Compute attention scores: QK^T */
                        /* For GQA: k->ne[2]=4, q->ne[2]=32, ggml broadcasts k over q */
                        kq = ggml_mul_mat(ctx_c, k, q);  /* [n_tokens, n_tokens, n_head] */
                        
                        /* Scale attention scores */
                        kq = ggml_scale(ctx_c, kq, scale);
                        
                        /* Apply causal mask using diag_mask_inf (sets upper triangle to -inf) */
                        kq = ggml_diag_mask_inf(ctx_c, kq, 0);
                        
                        /* Softmax */
                        kq = ggml_soft_max(ctx_c, kq);
                        
                        /* Apply attention weights to values */
                        /* kq: [n_tokens, n_tokens, n_head], v: [head_dim, n_tokens, n_head_kv] */
                        /* For the multiply: v needs to be transposed to [n_tokens, head_dim, n_head_kv] */
                        struct ggml_tensor *v_t = ggml_cont(ctx_c, ggml_transpose(ctx_c, v));
                        /* mul_mat broadcasts: kq @ v_t -> [head_dim, n_tokens, n_head] */
                        attn_out = ggml_mul_mat(ctx_c, v_t, kq);
                        
                        /* Permute to [head_dim, n_head, n_tokens] and make contiguous */
                        attn_out = ggml_cont(ctx_c, ggml_permute(ctx_c, attn_out, 0, 2, 1, 3));
                    }
                }
                
                /* Attention output is [head_dim, n_head, n_tokens] */
                /* Reshape to 2D: [n_embd, n_tokens] */
                attn_out = ggml_reshape_2d(ctx_c, attn_out, n_embd, n_tokens);
                
                /* Output projection */
                cur = ggml_mul_mat(ctx_c, wo, attn_out);
            }
            
            /* Residual connection after attention */
            cur = ggml_add(ctx_c, cur, residual);
            residual = cur;
            
            /* FFN: RMS norm -> gate/up projections -> SiLU -> down projection */
            cur = ggml_rms_norm(ctx_c, cur, rms_norm_eps);
            cur = ggml_mul(ctx_c, cur, ffn_norm);
            
            {
                struct ggml_tensor *gate = ggml_mul_mat(ctx_c, ffn_gate, cur);
                struct ggml_tensor *up = ggml_mul_mat(ctx_c, ffn_up, cur);
                
                /* SiLU activation on gate */
                gate = ggml_silu(ctx_c, gate);
                
                /* Element-wise multiply gate and up */
                cur = ggml_mul(ctx_c, gate, up);
                
                /* Down projection */
                cur = ggml_mul_mat(ctx_c, ffn_down, cur);
            }
            
            /* Residual connection after FFN */
            cur = ggml_add(ctx_c, cur, residual);
        }
        
        /* Final norm and output projection */
        if (output_norm) {
            cur = ggml_rms_norm(ctx_c, cur, rms_norm_eps);
            cur = ggml_mul(ctx_c, cur, output_norm);
        }
        
        /* Project to vocabulary (logits) */
        cur = ggml_mul_mat(ctx_c, output, cur);
        ggml_set_name(cur, "logits");
        
        /* Build graph */
        gf = ggml_new_graph(ctx_c);
        ggml_build_forward_expand(gf, cur);
    }
    
    /* Allocate compute buffer */
    allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
    if (!ggml_gallocr_reserve(allocr, gf)) {
        ggml_gallocr_free(allocr);
        ggml_free(ctx_c);
        ggml_backend_free(backend);
        Safefree(tokens);
        croak("Failed to reserve compute allocator");
    }
    
    if (!ggml_gallocr_alloc_graph(allocr, gf)) {
        ggml_gallocr_free(allocr);
        ggml_free(ctx_c);
        ggml_backend_free(backend);
        Safefree(tokens);
        croak("Failed to allocate compute graph");
    }
    
    /* Set input token data */
    {
        struct ggml_tensor *inp = ggml_graph_get_tensor(gf, "inp_tokens");
        struct ggml_tensor *pos_tensor = ggml_graph_get_tensor(gf, "pos");
        struct ggml_tensor *mask_tensor = ggml_graph_get_tensor(gf, "kq_mask");
        
        if (inp) {
            ggml_backend_tensor_set(inp, tokens, 0, n_tokens * sizeof(int));
        }
        
        /* Set position indices (0, 1, 2, ..., n_tokens-1) */
        if (pos_tensor) {
            int *positions;
            int p;
            Newx(positions, n_tokens, int);
            for (p = 0; p < n_tokens; p++) {
                positions[p] = p;
            }
            ggml_backend_tensor_set(pos_tensor, positions, 0, n_tokens * sizeof(int));
            Safefree(positions);
        }
        
        /* Set causal attention mask: 0 for allowed, -inf for masked */
        if (mask_tensor) {
            float *mask_data;
            int row, col;
            Newx(mask_data, n_tokens * n_tokens, float);
            for (row = 0; row < n_tokens; row++) {
                for (col = 0; col < n_tokens; col++) {
                    /* Causal mask: can only attend to current and previous positions */
                    if (col <= row) {
                        mask_data[row * n_tokens + col] = 0.0f;
                    } else {
                        mask_data[row * n_tokens + col] = -INFINITY;
                    }
                }
            }
            ggml_backend_tensor_set(mask_tensor, mask_data, 0, n_tokens * n_tokens * sizeof(float));
            Safefree(mask_data);
        }
    }
    
    /* Run the forward pass */
    if (ggml_backend_graph_compute(backend, gf) != GGML_STATUS_SUCCESS) {
        ggml_gallocr_free(allocr);
        ggml_free(ctx_c);
        ggml_backend_free(backend);
        Safefree(tokens);
        croak("Failed to compute graph");
    }
    
    /* Extract logits from the last token */
    {
        struct ggml_tensor *logits_tensor = ggml_graph_get_tensor(gf, "logits");
        if (logits_tensor) {
            float *logits_data;
            int j;
            size_t logits_size = n_vocab * sizeof(float);
            
            Newx(logits_data, n_vocab, float);
            
            /* Get logits for last token */
            ggml_backend_tensor_get(logits_tensor, logits_data, 
                                    (n_tokens - 1) * n_vocab * sizeof(float), logits_size);
            
            /* Return logits as array */
            EXTEND(SP, n_vocab);
            for (j = 0; j < n_vocab; j++) {
                mPUSHn(logits_data[j]);
            }
            
            Safefree(logits_data);
        }
    }
    
    /* Cleanup - note: ctx_w is model->ctx, don't free it */
    ggml_gallocr_free(allocr);
    ggml_free(ctx_c);
    ggml_backend_free(backend);
    Safefree(tokens);

int
sample_top_p(self, logits_ref, ...)
    SV *self
    SV *logits_ref
PREINIT:
    AV *logits_av;
    int n_vocab;
    float *logits;
    float temperature = 0.8f;
    float top_p = 0.95f;
    int i;
    float max_logit = -1e9f;
    float sum = 0.0f;
    float cumsum = 0.0f;
    float threshold;
    int *indices;
CODE:
    if (!SvROK(logits_ref) || SvTYPE(SvRV(logits_ref)) != SVt_PVAV) {
        croak("logits must be an array reference");
    }
    logits_av = (AV*)SvRV(logits_ref);
    n_vocab = av_len(logits_av) + 1;
    
    /* Parse optional parameters */
    for (i = 2; i < items; i += 2) {
        if (i + 1 < items) {
            const char *key = SvPV_nolen(ST(i));
            if (strEQ(key, "temperature")) {
                temperature = SvNV(ST(i + 1));
            } else if (strEQ(key, "top_p")) {
                top_p = SvNV(ST(i + 1));
            }
        }
    }
    
    Newx(logits, n_vocab, float);
    Newx(indices, n_vocab, int);
    
    /* Copy logits and find max */
    for (i = 0; i < n_vocab; i++) {
        SV **svp = av_fetch(logits_av, i, 0);
        logits[i] = svp ? SvNV(*svp) : 0.0f;
        if (logits[i] > max_logit) max_logit = logits[i];
        indices[i] = i;
    }
    
    /* Apply temperature and softmax */
    for (i = 0; i < n_vocab; i++) {
        logits[i] = expf((logits[i] - max_logit) / temperature);
        sum += logits[i];
    }
    for (i = 0; i < n_vocab; i++) {
        logits[i] /= sum;
    }
    
    /* Sort by probability (simple bubble sort for now - ok for top_p sampling) */
    {
        int j, swapped;
        for (i = 0; i < n_vocab - 1; i++) {
            swapped = 0;
            for (j = 0; j < n_vocab - i - 1; j++) {
                if (logits[j] < logits[j + 1]) {
                    float tmp = logits[j];
                    int tmp_idx = indices[j];
                    logits[j] = logits[j + 1];
                    indices[j] = indices[j + 1];
                    logits[j + 1] = tmp;
                    indices[j + 1] = tmp_idx;
                    swapped = 1;
                }
            }
            if (!swapped) break;
            /* Early exit when we have enough probability mass */
            cumsum = 0;
            for (j = 0; j <= i; j++) cumsum += logits[j];
            if (cumsum >= top_p) break;
        }
    }
    
    /* Sample from top_p tokens */
    threshold = (float)rand() / (float)RAND_MAX * top_p;
    cumsum = 0.0f;
    RETVAL = indices[0];  /* Default to most likely */
    
    for (i = 0; i < n_vocab; i++) {
        cumsum += logits[i];
        if (cumsum >= threshold) {
            RETVAL = indices[i];
            break;
        }
        if (cumsum >= top_p) break;
    }
    
    Safefree(logits);
    Safefree(indices);
OUTPUT:
    RETVAL

int
sample_top_k(self, logits_ref, ...)
    SV *self
    SV *logits_ref
PREINIT:
    AV *logits_av;
    int n_vocab;
    float *logits;
    float temperature = 0.8f;
    int top_k = 40;
    int i;
    float max_logit = -1e9f;
    float sum = 0.0f;
    float threshold;
    int *indices;
CODE:
    if (!SvROK(logits_ref) || SvTYPE(SvRV(logits_ref)) != SVt_PVAV) {
        croak("logits must be an array reference");
    }
    logits_av = (AV*)SvRV(logits_ref);
    n_vocab = av_len(logits_av) + 1;
    
    /* Parse optional parameters */
    for (i = 2; i < items; i += 2) {
        if (i + 1 < items) {
            const char *key = SvPV_nolen(ST(i));
            if (strEQ(key, "temperature")) {
                temperature = SvNV(ST(i + 1));
            } else if (strEQ(key, "top_k")) {
                top_k = SvIV(ST(i + 1));
            }
        }
    }
    
    if (top_k <= 0 || top_k > n_vocab) top_k = n_vocab;
    
    Newx(logits, n_vocab, float);
    Newx(indices, n_vocab, int);
    
    /* Copy logits and find max */
    for (i = 0; i < n_vocab; i++) {
        SV **svp = av_fetch(logits_av, i, 0);
        logits[i] = svp ? SvNV(*svp) : 0.0f;
        if (logits[i] > max_logit) max_logit = logits[i];
        indices[i] = i;
    }
    
    /* Apply temperature and softmax */
    sum = 0.0f;
    for (i = 0; i < n_vocab; i++) {
        logits[i] = expf((logits[i] - max_logit) / temperature);
        sum += logits[i];
    }
    for (i = 0; i < n_vocab; i++) {
        logits[i] /= sum;
    }
    
    /* Partial sort to get top_k elements (selection-style) */
    {
        int j, k;
        for (k = 0; k < top_k; k++) {
            int max_idx = k;
            for (j = k + 1; j < n_vocab; j++) {
                if (logits[j] > logits[max_idx]) {
                    max_idx = j;
                }
            }
            if (max_idx != k) {
                float tmp = logits[k];
                int tmp_idx = indices[k];
                logits[k] = logits[max_idx];
                indices[k] = indices[max_idx];
                logits[max_idx] = tmp;
                indices[max_idx] = tmp_idx;
            }
        }
    }
    
    /* Renormalize top_k probabilities */
    sum = 0.0f;
    for (i = 0; i < top_k; i++) {
        sum += logits[i];
    }
    
    /* Sample from top_k tokens */
    threshold = (float)rand() / (float)RAND_MAX * sum;
    float cumsum = 0.0f;
    RETVAL = indices[0];  /* Default to most likely */
    
    for (i = 0; i < top_k; i++) {
        cumsum += logits[i];
        if (cumsum >= threshold) {
            RETVAL = indices[i];
            break;
        }
    }
    
    Safefree(logits);
    Safefree(indices);
OUTPUT:
    RETVAL

void
generate(self, tokens_ref, ...)
    SV *self
    SV *tokens_ref
PREINIT:
    HV *hv;
    SV **svp;
    LughModel *model;
    AV *tokens_av;
    AV *result_av;
    int *tokens = NULL;
    int n_tokens;
    int max_tokens = 128;
    float temperature = 0.8f;
    float top_p = 0.95f;
    int top_k = 40;
    int eos_token = 2;
    int greedy = 0;
    SV *callback = NULL;
    int i;
    int n_result;
    SV **orig_sp;
PPCODE:
    orig_sp = SP;  /* Save original stack pointer */
    
    /* Get model */
    hv = (HV*)SvRV(self);
    svp = hv_fetch(hv, "_model", 6, 0);
    if (!svp || !*svp) croak("No model in inference object");
    model = get_lugh_model(aTHX_ *svp);
    if (!model) croak("Invalid model");
    
    /* Parse input tokens */
    if (!SvROK(tokens_ref) || SvTYPE(SvRV(tokens_ref)) != SVt_PVAV) {
        croak("generate() requires an array reference of tokens");
    }
    tokens_av = (AV*)SvRV(tokens_ref);
    n_tokens = av_len(tokens_av) + 1;
    if (n_tokens == 0) {
        croak("generate() requires at least one token");
    }
    
    /* Parse optional parameters */
    for (i = 2; i < items; i += 2) {
        if (i + 1 < items) {
            const char *key = SvPV_nolen(ST(i));
            if (strEQ(key, "max_tokens")) {
                max_tokens = SvIV(ST(i + 1));
            } else if (strEQ(key, "temperature")) {
                temperature = SvNV(ST(i + 1));
            } else if (strEQ(key, "top_p")) {
                top_p = SvNV(ST(i + 1));
            } else if (strEQ(key, "top_k")) {
                top_k = SvIV(ST(i + 1));
            } else if (strEQ(key, "eos_token")) {
                eos_token = SvIV(ST(i + 1));
            } else if (strEQ(key, "greedy")) {
                greedy = SvTRUE(ST(i + 1));
            } else if (strEQ(key, "callback")) {
                if (SvROK(ST(i + 1)) && SvTYPE(SvRV(ST(i + 1))) == SVt_PVCV) {
                    callback = ST(i + 1);
                }
            }
        }
    }
    
    /* Get EOS from model if not specified */
    {
        int64_t key_id = gguf_find_key(model->gguf, "tokenizer.ggml.eos_token_id");
        if (key_id >= 0) {
            eos_token = gguf_get_val_u32(model->gguf, key_id);
        }
    }
    
    /* Initialize tokens array with prompt */
    Newx(tokens, n_tokens + max_tokens, int);
    for (i = 0; i < n_tokens; i++) {
        SV **elem = av_fetch(tokens_av, i, 0);
        tokens[i] = elem ? SvIV(*elem) : 0;
    }
    
    /* Create result array for generated tokens only */
    result_av = newAV();
    
    /* Generation loop */
    {
        int gen_count = 0;
        int current_len = n_tokens;
        
        while (gen_count < max_tokens) {
            AV *input_av;
            SV *input_ref;
            int next_token;
            AV *logits_av;
            int n_vocab;
            int j;
            
            /* Build input array for forward pass */
            input_av = newAV();
            for (j = 0; j < current_len; j++) {
                av_push(input_av, newSViv(tokens[j]));
            }
            input_ref = newRV_noinc((SV*)input_av);
            
            /* Call forward() to get logits - create logits_av outside scope */
            logits_av = newAV();
            {
                dSP;
                int count;
                ENTER;
                SAVETMPS;
                
                PUSHMARK(SP);
                XPUSHs(self);
                XPUSHs(input_ref);
                PUTBACK;
                
                count = call_method("forward", G_ARRAY);
                
                SPAGAIN;
                
                /* Debug output */
                
                /* Collect logits into array - pop in reverse order */
                av_extend(logits_av, count - 1);
                for (j = count - 1; j >= 0; j--) {
                    SV *val = POPs;
                    av_store(logits_av, j, newSVnv(SvNV(val)));
                }
                
                PUTBACK;
                FREETMPS;
                LEAVE;
            }
            
            SvREFCNT_dec(input_ref);
            n_vocab = av_len(logits_av) + 1;
            
            
            /* Sample next token */
            if (greedy) {
                /* Argmax for greedy sampling */
                float max_val = -1e9f;
                next_token = 0;
                for (j = 0; j < n_vocab; j++) {
                    SV **svp = av_fetch(logits_av, j, 0);
                    float val = svp ? SvNV(*svp) : 0.0f;
                    if (val > max_val) {
                        max_val = val;
                        next_token = j;
                    }
                }
            } else if (top_k > 0 && top_k < 1000) {
                /* Use top_k sampling inline */
                float *probs;
                int *indices;
                float max_logit = -1e9f;
                float sum = 0.0f;
                float threshold, cumsum;
                int k;
                
                Newx(probs, n_vocab, float);
                Newx(indices, n_vocab, int);
                
                /* Copy logits and find max */
                for (j = 0; j < n_vocab; j++) {
                    SV **svp = av_fetch(logits_av, j, 0);
                    probs[j] = svp ? SvNV(*svp) : 0.0f;
                    if (probs[j] > max_logit) max_logit = probs[j];
                    indices[j] = j;
                }
                
                /* Apply temperature and softmax */
                for (j = 0; j < n_vocab; j++) {
                    probs[j] = expf((probs[j] - max_logit) / temperature);
                    sum += probs[j];
                }
                for (j = 0; j < n_vocab; j++) {
                    probs[j] /= sum;
                }
                
                /* Partial sort to get top_k elements */
                for (k = 0; k < top_k && k < n_vocab; k++) {
                    int max_idx = k;
                    for (j = k + 1; j < n_vocab; j++) {
                        if (probs[j] > probs[max_idx]) max_idx = j;
                    }
                    if (max_idx != k) {
                        float tmp = probs[k];
                        int tmp_idx = indices[k];
                        probs[k] = probs[max_idx];
                        indices[k] = indices[max_idx];
                        probs[max_idx] = tmp;
                        indices[max_idx] = tmp_idx;
                    }
                }
                
                /* Renormalize and sample */
                sum = 0.0f;
                for (k = 0; k < top_k && k < n_vocab; k++) sum += probs[k];
                threshold = (float)rand() / (float)RAND_MAX * sum;
                cumsum = 0.0f;
                next_token = indices[0];
                for (k = 0; k < top_k && k < n_vocab; k++) {
                    cumsum += probs[k];
                    if (cumsum >= threshold) {
                        next_token = indices[k];
                        break;
                    }
                }
                
                Safefree(probs);
                Safefree(indices);
            } else {
                /* Use top_p sampling inline */
                float *probs;
                int *indices;
                float max_logit = -1e9f;
                float sum = 0.0f;
                float threshold, cumsum;
                
                Newx(probs, n_vocab, float);
                Newx(indices, n_vocab, int);
                
                /* Copy logits and find max */
                for (j = 0; j < n_vocab; j++) {
                    SV **svp = av_fetch(logits_av, j, 0);
                    probs[j] = svp ? SvNV(*svp) : 0.0f;
                    if (probs[j] > max_logit) max_logit = probs[j];
                    indices[j] = j;
                }
                
                /* Apply temperature and softmax */
                for (j = 0; j < n_vocab; j++) {
                    probs[j] = expf((probs[j] - max_logit) / temperature);
                    sum += probs[j];
                }
                for (j = 0; j < n_vocab; j++) {
                    probs[j] /= sum;
                }
                
                /* Sort by probability (bubble sort with early exit) */
                {
                    int swapped, k;
                    for (j = 0; j < n_vocab - 1; j++) {
                        swapped = 0;
                        cumsum = 0;
                        for (k = 0; k < n_vocab - j - 1; k++) {
                            if (probs[k] < probs[k + 1]) {
                                float tmp = probs[k];
                                int tmp_idx = indices[k];
                                probs[k] = probs[k + 1];
                                indices[k] = indices[k + 1];
                                probs[k + 1] = tmp;
                                indices[k + 1] = tmp_idx;
                                swapped = 1;
                            }
                        }
                        if (!swapped) break;
                        for (k = 0; k <= j; k++) cumsum += probs[k];
                        if (cumsum >= top_p) break;
                    }
                }
                
                /* Sample from top_p tokens */
                threshold = (float)rand() / (float)RAND_MAX * top_p;
                cumsum = 0.0f;
                next_token = indices[0];
                for (j = 0; j < n_vocab; j++) {
                    cumsum += probs[j];
                    if (cumsum >= threshold) {
                        next_token = indices[j];
                        break;
                    }
                    if (cumsum >= top_p) break;
                }
                
                Safefree(probs);
                Safefree(indices);
            }
            
            SvREFCNT_dec((SV*)logits_av);
            
            /* Add to results */
            av_push(result_av, newSViv(next_token));
            tokens[current_len] = next_token;
            current_len++;
            gen_count++;
            
            /* Call streaming callback if provided */
            if (callback) {
                dSP;
                int should_stop;
                
                ENTER;
                SAVETMPS;
                
                PUSHMARK(SP);
                XPUSHs(sv_2mortal(newSViv(next_token)));
                XPUSHs(sv_2mortal(newSViv(gen_count)));
                PUTBACK;
                
                call_sv(callback, G_SCALAR);
                
                SPAGAIN;
                should_stop = POPi;
                
                PUTBACK;
                FREETMPS;
                LEAVE;
                
                /* Callback returns true to stop generation */
                if (should_stop) break;
            }
            
            /* Check for EOS token */
            if (next_token == eos_token) break;
        }
    }
    
    Safefree(tokens);
    
    /* Return generated tokens as list */
    n_result = av_len(result_av) + 1;
    {
        /* Use XSRETURN explicitly */
        int count = 0;
        SP = orig_sp;  /* Restore original stack pointer */
        for (i = 0; i < n_result; i++) {
            SV **svp = av_fetch(result_av, i, 0);
            if (svp) {
                XST_mIV(count, SvIV(*svp));
                count++;
            }
        }
        SvREFCNT_dec(result_av);
        XSRETURN(count);
    }

MODULE = Lugh  PACKAGE = Lugh::Tokenizer

=pod

=head1 Lugh::Tokenizer

Tokenizer for encoding text to tokens and decoding tokens back to text.
Reads vocabulary from the GGUF model file.

=cut

SV *
new(class, ...)
    const char *class
PREINIT:
    LughModel *model = NULL;
    SV *model_sv = NULL;
    int i;
CODE:
    /*
     * Lugh::Tokenizer->new(model => $model)
     *
     * Builds a tokenizer object (blessed hash) from the vocabulary stored
     * under "tokenizer.ggml.tokens" in the model's GGUF metadata.  The
     * hash holds a reference to the model, token->id and id->token
     * lookups, the vocabulary size and the bos/eos/unk ids (defaults
     * 1/2/0 when the model does not define them).
     *
     * Croaks on an odd argument list, a missing/invalid model, or a model
     * without a vocabulary.
     */
    if ((items - 1) % 2 != 0) {
        croak("Usage: Lugh::Tokenizer->new(model => $model)");
    }
    
    for (i = 1; i < items; i += 2) {
        const char *key = SvPV_nolen(ST(i));
        SV *val = ST(i + 1);
        
        if (strEQ(key, "model")) {
            model_sv = val;
            model = get_lugh_model(aTHX_ val);
        }
    }
    
    if (!model) {
        croak("model parameter is required");
    }
    
    {
        HV *hv;
        HV *token_to_id;
        AV *id_to_token;
        int64_t tokens_key;
        int64_t key_id;
        int64_t n_vocab;
        int64_t j;
        int bos_id = 1, eos_id = 2, unk_id = 0;
        SV *sv;
        
        /* Check for the vocabulary before allocating anything, so the
         * failure path has nothing to clean up. */
        tokens_key = gguf_find_key(model->gguf, "tokenizer.ggml.tokens");
        if (tokens_key < 0) {
            croak("No vocabulary found in model");
        }
        
        /* Get special token IDs */
        key_id = gguf_find_key(model->gguf, "tokenizer.ggml.bos_token_id");
        if (key_id >= 0) bos_id = gguf_get_val_u32(model->gguf, key_id);
        
        key_id = gguf_find_key(model->gguf, "tokenizer.ggml.eos_token_id");
        if (key_id >= 0) eos_id = gguf_get_val_u32(model->gguf, key_id);
        
        key_id = gguf_find_key(model->gguf, "tokenizer.ggml.unknown_token_id");
        if (key_id >= 0) unk_id = gguf_get_val_u32(model->gguf, key_id);
        
        /* Load vocabulary */
        n_vocab = gguf_get_arr_n(model->gguf, tokens_key);
        
        hv = newHV();
        token_to_id = newHV();
        id_to_token = newAV();
        if (n_vocab > 0) {
            av_extend(id_to_token, n_vocab - 1);
        }
        
        for (j = 0; j < n_vocab; j++) {
            const char *tok = gguf_get_arr_str(model->gguf, tokens_key, j);
            STRLEN len = strlen(tok);
            
            /* id -> token text */
            av_store(id_to_token, j, newSVpvn(tok, len));
            
            /* token text -> id */
            hv_store(token_to_id, tok, len, newSViv(j), 0);
        }
        
        /* The tokenizer keeps its own reference to the model SV */
        hv_store(hv, "_model", 6, SvREFCNT_inc(model_sv), 0);
        hv_store(hv, "_token_to_id", 12, newRV_noinc((SV*)token_to_id), 0);
        hv_store(hv, "_id_to_token", 12, newRV_noinc((SV*)id_to_token), 0);
        hv_store(hv, "n_vocab", 7, newSViv(n_vocab), 0);
        hv_store(hv, "bos_id", 6, newSViv(bos_id), 0);
        hv_store(hv, "eos_id", 6, newSViv(eos_id), 0);
        hv_store(hv, "unk_id", 6, newSViv(unk_id), 0);
        
        sv = newRV_noinc((SV*)hv);
        sv_bless(sv, gv_stashpv(class, GV_ADD));
        RETVAL = sv;
    }
OUTPUT:
    RETVAL

int
n_vocab(self)
    SV *self
PREINIT:
    HV *hv;
    SV **svp;
CODE:
    hv = (HV*)SvRV(self);
    svp = hv_fetch(hv, "n_vocab", 7, 0);
    RETVAL = svp ? SvIV(*svp) : 0;
OUTPUT:
    RETVAL

int
bos_id(self)
    SV *self
PREINIT:
    HV *hv;
    SV **svp;
CODE:
    hv = (HV*)SvRV(self);
    svp = hv_fetch(hv, "bos_id", 6, 0);
    RETVAL = svp ? SvIV(*svp) : 1;
OUTPUT:
    RETVAL

int
eos_id(self)
    SV *self
PREINIT:
    HV *hv;
    SV **svp;
CODE:
    hv = (HV*)SvRV(self);
    svp = hv_fetch(hv, "eos_id", 6, 0);
    RETVAL = svp ? SvIV(*svp) : 2;
OUTPUT:
    RETVAL

SV *
decode(self, ...)
    SV *self
PREINIT:
    HV *hv;
    SV **svp;
    AV *id_to_token;
    SV *result;
    int i;
    int skip_special = 0;
CODE:
    hv = (HV*)SvRV(self);
    svp = hv_fetch(hv, "_id_to_token", 12, 0);
    if (!svp || !SvROK(*svp)) {
        croak("Tokenizer not initialized properly");
    }
    id_to_token = (AV*)SvRV(*svp);
    
    result = newSVpv("", 0);
    
    /* Handle array reference or list of token ids */
    if (items == 2 && SvROK(ST(1)) && SvTYPE(SvRV(ST(1))) == SVt_PVAV) {
        /* Array reference passed */
        AV *av = (AV*)SvRV(ST(1));
        int n = av_len(av) + 1;
        for (i = 0; i < n; i++) {
            SV **elem = av_fetch(av, i, 0);
            if (elem && *elem) {
                int token_id = SvIV(*elem);
                SV **tokp = av_fetch(id_to_token, token_id, 0);
                if (tokp && *tokp) {
                    const char *tok = SvPV_nolen(*tokp);
                    /* Skip special tokens like <s>, </s>, etc if needed */
                    if (tok[0] != '<' || !strchr(tok, '>')) {
                        /* Handle SentencePiece underscore prefix (▁ -> space) */
                        if ((unsigned char)tok[0] == 0xE2 && 
                            (unsigned char)tok[1] == 0x96 && 
                            (unsigned char)tok[2] == 0x81) {
                            sv_catpvn(result, " ", 1);
                            sv_catpv(result, tok + 3);
                        } else {
                            sv_catpv(result, tok);
                        }
                    }
                }
            }
        }
    } else {
        /* List of token ids passed directly */
        for (i = 1; i < items; i++) {
            int token_id = SvIV(ST(i));
            SV **tokp = av_fetch(id_to_token, token_id, 0);
            if (tokp && *tokp) {
                const char *tok = SvPV_nolen(*tokp);
                /* Skip special tokens like <s>, </s>, etc if needed */
                if (tok[0] != '<' || !strchr(tok, '>')) {
                    /* Handle SentencePiece underscore prefix (▁ -> space) */
                    if ((unsigned char)tok[0] == 0xE2 && 
                        (unsigned char)tok[1] == 0x96 && 
                        (unsigned char)tok[2] == 0x81) {
                        sv_catpvn(result, " ", 1);
                        sv_catpv(result, tok + 3);
                    } else {
                        sv_catpv(result, tok);
                    }
                }
            }
        }
    }
    
    RETVAL = result;
OUTPUT:
    RETVAL

void
encode(self, text, ...)
    SV *self
    SV *text
PREINIT:
    HV *hv;
    SV **svp;
    HV *token_to_id;
    AV *id_to_token;
    const char *str;
    STRLEN len;
    AV *tokens;
    int add_bos = 1;
    int bos_id, eos_id, unk_id;
    size_t pos;
    int i;
PPCODE:
    hv = (HV*)SvRV(self);
    
    svp = hv_fetch(hv, "_token_to_id", 12, 0);
    if (!svp || !SvROK(*svp)) croak("Tokenizer not initialized");
    token_to_id = (HV*)SvRV(*svp);
    
    svp = hv_fetch(hv, "_id_to_token", 12, 0);
    if (!svp || !SvROK(*svp)) croak("Tokenizer not initialized");
    id_to_token = (AV*)SvRV(*svp);
    
    svp = hv_fetch(hv, "bos_id", 6, 0);
    bos_id = svp ? SvIV(*svp) : 1;
    svp = hv_fetch(hv, "eos_id", 6, 0);
    eos_id = svp ? SvIV(*svp) : 2;
    svp = hv_fetch(hv, "unk_id", 6, 0);
    unk_id = svp ? SvIV(*svp) : 0;
    
    /* Parse optional add_bos parameter */
    for (i = 2; i < items; i += 2) {
        if (i + 1 < items) {
            const char *key = SvPV_nolen(ST(i));
            if (strEQ(key, "add_bos")) {
                add_bos = SvIV(ST(i + 1));
            }
        }
    }
    
    str = SvPV(text, len);
    
    /* Simple greedy tokenization (longest match first) */
    /* For production, should use proper BPE merge algorithm */
    
    if (add_bos) {
        XPUSHs(sv_2mortal(newSViv(bos_id)));
    }
    
    pos = 0;
    while (pos < len) {
        int best_len = 0;
        int best_id = unk_id;
        int try_len;
        char buf[256];
        int at_word_start = (pos == 0 || str[pos-1] == ' ' || str[pos-1] == '\n' || str[pos-1] == '\t');
        
        /* Skip space - it becomes part of the next token's ▁ prefix */
        if (str[pos] == ' ' || str[pos] == '\t') {
            pos++;
            continue;
        }
        
        /* Try to find longest matching token */
        for (try_len = (len - pos > 255 ? 255 : len - pos); try_len > 0; try_len--) {
            SV **id_ptr;
            
            /* Copy substring to buffer */
            memcpy(buf, str + pos, try_len);
            buf[try_len] = '\0';
            
            /* Try with SentencePiece prefix for word start */
            if (at_word_start) {
                char sp_buf[260];
                /* ▁ = 0xE2 0x96 0x81 in UTF-8 */
                sp_buf[0] = 0xE2;
                sp_buf[1] = 0x96;
                sp_buf[2] = 0x81;
                memcpy(sp_buf + 3, buf, try_len + 1);
                id_ptr = hv_fetch(token_to_id, sp_buf, try_len + 3, 0);
                if (id_ptr && *id_ptr) {
                    best_id = SvIV(*id_ptr);
                    best_len = try_len;
                    break;
                }
            }
            
            /* Try without prefix */
            id_ptr = hv_fetch(token_to_id, buf, try_len, 0);
            if (id_ptr && *id_ptr) {
                best_id = SvIV(*id_ptr);
                best_len = try_len;
                break;
            }
        }
        
        if (best_len == 0) {
            /* Skip unknown character */
            pos++;
            XPUSHs(sv_2mortal(newSViv(unk_id)));
        } else {
            pos += best_len;
            XPUSHs(sv_2mortal(newSViv(best_id)));
        }
    }

#ifdef USE_ITHREADS

void
CLONE(class)
    char *class
CODE:
    /*
     * Thread-clone hook for this package (only compiled when
     * USE_ITHREADS is defined).
     *
     * This body is currently a no-op: it only silences the unused
     * parameter warning and touches no context or tensor state.
     *
     * NOTE(review): the intent recorded here previously was that
     * contexts and tensors must not be shared across threads and that
     * cloned contexts get invalidated in this hook.  No invalidation
     * code exists in this XSUB - confirm whether it happens elsewhere
     * or still needs to be written.
     */
    PERL_UNUSED_VAR(class);
    /*
     * NOTE(review): any magic cleanup attached to cloned objects is
     * presumably still invoked in the new thread - verify that it
     * cannot free a context twice.
     */

#endif

MODULE = Lugh    PACKAGE = Lugh::Tensor

SV *
new_f32(class, ctx_sv, ...)
    char *class
    SV *ctx_sv
PREINIT:
    LughContext *lctx;
    struct ggml_tensor *tensor = NULL;
    int64_t ne[4] = {1, 1, 1, 1};
    int n_dims = 1;
    int i;
CODE:
    /*
     * Create an F32 tensor in the ggml context behind ctx_sv.
     *
     * The arguments after ctx_sv are the dimension sizes (at most 4).
     * When none are given, a 1-D tensor with a single element is created.
     * Croaks on more than 4 dimensions, on a negative dimension, or when
     * ggml returns no tensor.
     *
     * Returns a reference, blessed into `class`, to an IV that holds the
     * raw ggml_tensor pointer.
     */
    lctx = get_lugh_context(aTHX_ ctx_sv);

    /* Only read the dimensions that were actually passed on the stack. */
    n_dims = items - 2;
    if (n_dims > 4) croak("Maximum 4 dimensions supported");

    for (i = 0; i < n_dims; i++) {
        ne[i] = SvIV(ST(i + 2));
        if (ne[i] < 0) {
            croak("Tensor dimension %d must not be negative", i);
        }
    }

    /* No dimensions supplied: fall back to the default ne[0] == 1. */
    if (n_dims < 1) n_dims = 1;

    /* Create tensor based on dimensionality */
    switch (n_dims) {
        case 1:
            tensor = ggml_new_tensor_1d(lctx->ctx, GGML_TYPE_F32, ne[0]);
            break;
        case 2:
            tensor = ggml_new_tensor_2d(lctx->ctx, GGML_TYPE_F32, ne[0], ne[1]);
            break;
        case 3:
            tensor = ggml_new_tensor_3d(lctx->ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2]);
            break;
        case 4:
            tensor = ggml_new_tensor_4d(lctx->ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], ne[3]);
            break;
        default:
            break;
    }

    if (!tensor) {
        croak("Failed to create tensor");
    }

    /* Return tensor pointer as blessed IV */
    RETVAL = sv_bless(
        newRV_noinc(newSViv(PTR2IV(tensor))),
        gv_stashpv(class, GV_ADD)
    );
OUTPUT:
    RETVAL

void
set_f32(self, ...)
    SV *self
PREINIT:
    struct ggml_tensor *tensor;
    int64_t i, n_elements;
CODE:
    /*
     * Store the values passed after self into the tensor, one per element,
     * using ggml_set_f32_1d. Croaks when self is not a tensor handle or
     * when the number of values differs from ggml_nelements(tensor).
     */
    if (!SvROK(self) || !SvIOK(SvRV(self))) {
        croak("set_f32: self is not a Lugh::Tensor handle");
    }
    tensor = INT2PTR(struct ggml_tensor *, SvIV(SvRV(self)));
    if (!tensor) {
        croak("set_f32: tensor handle is NULL");
    }
    n_elements = ggml_nelements(tensor);

    if (items - 1 != n_elements) {
        croak("Expected %ld values, got %d", (long)n_elements, (int)(items - 1));
    }

    for (i = 0; i < n_elements; i++) {
        ggml_set_f32_1d(tensor, i, SvNV(ST(i + 1)));
    }

void
get_f32(self)
    SV *self
PREINIT:
    struct ggml_tensor *tensor;
    int64_t i, n_elements;
PPCODE:
    /*
     * Return every element of the tensor as a flat list of numbers,
     * read with ggml_get_f32_1d. Croaks when self is not a tensor handle.
     */
    if (!SvROK(self) || !SvIOK(SvRV(self))) {
        croak("get_f32: self is not a Lugh::Tensor handle");
    }
    tensor = INT2PTR(struct ggml_tensor *, SvIV(SvRV(self)));
    if (!tensor) {
        croak("get_f32: tensor handle is NULL");
    }
    n_elements = ggml_nelements(tensor);

    EXTEND(SP, n_elements);
    for (i = 0; i < n_elements; i++) {
        mPUSHn(ggml_get_f32_1d(tensor, i));
    }

int64_t
nelements(self)
    SV *self
PREINIT:
    struct ggml_tensor *tensor;
CODE:
    /*
     * Return ggml_nelements() for the tensor behind self.
     * Croaks when self is not a tensor handle.
     */
    if (!SvROK(self) || !SvIOK(SvRV(self))) {
        croak("nelements: self is not a Lugh::Tensor handle");
    }
    tensor = INT2PTR(struct ggml_tensor *, SvIV(SvRV(self)));
    if (!tensor) {
        croak("nelements: tensor handle is NULL");
    }
    RETVAL = ggml_nelements(tensor);
OUTPUT:
    RETVAL

int
n_dims(self)
    SV *self
PREINIT:
    struct ggml_tensor *tensor;
CODE:
    /*
     * Return ggml_n_dims() for the tensor behind self.
     * Croaks when self is not a tensor handle.
     */
    if (!SvROK(self) || !SvIOK(SvRV(self))) {
        croak("n_dims: self is not a Lugh::Tensor handle");
    }
    tensor = INT2PTR(struct ggml_tensor *, SvIV(SvRV(self)));
    if (!tensor) {
        croak("n_dims: tensor handle is NULL");
    }
    RETVAL = ggml_n_dims(tensor);
OUTPUT:
    RETVAL

void
shape(self)
    SV *self
PREINIT:
    struct ggml_tensor *tensor;
    int i, n_dims;
PPCODE:
    /*
     * Return the first ggml_n_dims(tensor) entries of tensor->ne as a list.
     * Croaks when self is not a tensor handle.
     */
    if (!SvROK(self) || !SvIOK(SvRV(self))) {
        croak("shape: self is not a Lugh::Tensor handle");
    }
    tensor = INT2PTR(struct ggml_tensor *, SvIV(SvRV(self)));
    if (!tensor) {
        croak("shape: tensor handle is NULL");
    }
    n_dims = ggml_n_dims(tensor);

    EXTEND(SP, n_dims);
    for (i = 0; i < n_dims; i++) {
        mPUSHi(tensor->ne[i]);
    }

MODULE = Lugh    PACKAGE = Lugh::Ops

SV *
add(ctx_sv, a_sv, b_sv)
    SV *ctx_sv
    SV *a_sv
    SV *b_sv
PREINIT:
    LughContext *lctx;
    struct ggml_tensor *a, *b, *result;
CODE:
    /*
     * Build a ggml_add node for tensors a and b in the given context and
     * return it as a new Lugh::Tensor handle.
     *
     * Croaks when either operand is not a tensor handle, when b cannot be
     * repeated to the shape of a (checked up front so the mismatch is
     * reported as a Perl exception), or when ggml returns no result.
     */
    lctx = get_lugh_context(aTHX_ ctx_sv);

    if (!SvROK(a_sv) || !SvIOK(SvRV(a_sv)) ||
        !SvROK(b_sv) || !SvIOK(SvRV(b_sv))) {
        croak("add: operands must be Lugh::Tensor handles");
    }
    a = INT2PTR(struct ggml_tensor *, SvIV(SvRV(a_sv)));
    b = INT2PTR(struct ggml_tensor *, SvIV(SvRV(b_sv)));
    if (!a || !b) {
        croak("add: tensor handle is NULL");
    }
    if (!ggml_can_repeat(b, a)) {
        croak("add: shape of second operand cannot be broadcast to the first");
    }

    result = ggml_add(lctx->ctx, a, b);
    if (!result) {
        croak("ggml_add failed");
    }

    RETVAL = sv_bless(
        newRV_noinc(newSViv(PTR2IV(result))),
        gv_stashpv("Lugh::Tensor", GV_ADD)
    );
OUTPUT:
    RETVAL

SV *
mul(ctx_sv, a_sv, b_sv)
    SV *ctx_sv
    SV *a_sv
    SV *b_sv
PREINIT:
    LughContext *lctx;
    struct ggml_tensor *a, *b, *result;
CODE:
    /*
     * Build a ggml_mul (element-wise multiply) node for tensors a and b in
     * the given context and return it as a new Lugh::Tensor handle.
     *
     * Croaks when either operand is not a tensor handle, when b cannot be
     * repeated to the shape of a (checked up front so the mismatch is
     * reported as a Perl exception), or when ggml returns no result.
     */
    lctx = get_lugh_context(aTHX_ ctx_sv);

    if (!SvROK(a_sv) || !SvIOK(SvRV(a_sv)) ||
        !SvROK(b_sv) || !SvIOK(SvRV(b_sv))) {
        croak("mul: operands must be Lugh::Tensor handles");
    }
    a = INT2PTR(struct ggml_tensor *, SvIV(SvRV(a_sv)));
    b = INT2PTR(struct ggml_tensor *, SvIV(SvRV(b_sv)));
    if (!a || !b) {
        croak("mul: tensor handle is NULL");
    }
    if (!ggml_can_repeat(b, a)) {
        croak("mul: shape of second operand cannot be broadcast to the first");
    }

    result = ggml_mul(lctx->ctx, a, b);
    if (!result) {
        croak("ggml_mul failed");
    }

    RETVAL = sv_bless(
        newRV_noinc(newSViv(PTR2IV(result))),
        gv_stashpv("Lugh::Tensor", GV_ADD)
    );
OUTPUT:
    RETVAL

SV *
mul_mat(ctx_sv, a_sv, b_sv)
    SV *ctx_sv
    SV *a_sv
    SV *b_sv
PREINIT:
    LughContext *lctx;
    struct ggml_tensor *a, *b, *result;
CODE:
    /*
     * Build a ggml_mul_mat node for tensors a and b in the given context
     * and return it as a new Lugh::Tensor handle.
     *
     * Croaks when either operand is not a tensor handle, when the shapes
     * are incompatible, when a is transposed, or when ggml returns no
     * result. The shape checks are done here so that a mismatch surfaces
     * as a Perl exception.
     */
    lctx = get_lugh_context(aTHX_ ctx_sv);

    if (!SvROK(a_sv) || !SvIOK(SvRV(a_sv)) ||
        !SvROK(b_sv) || !SvIOK(SvRV(b_sv))) {
        croak("mul_mat: operands must be Lugh::Tensor handles");
    }
    a = INT2PTR(struct ggml_tensor *, SvIV(SvRV(a_sv)));
    b = INT2PTR(struct ggml_tensor *, SvIV(SvRV(b_sv)));
    if (!a || !b) {
        croak("mul_mat: tensor handle is NULL");
    }

    /* Inner dimensions must match and a must be broadcastable over b. */
    if (a->ne[0] != b->ne[0] ||
        a->ne[2] <= 0 || a->ne[3] <= 0 ||
        b->ne[2] % a->ne[2] != 0 ||
        b->ne[3] % a->ne[3] != 0) {
        croak("mul_mat: incompatible tensor shapes");
    }
    if (ggml_is_transposed(a)) {
        croak("mul_mat: first operand must not be transposed");
    }

    result = ggml_mul_mat(lctx->ctx, a, b);
    if (!result) {
        croak("ggml_mul_mat failed");
    }

    RETVAL = sv_bless(
        newRV_noinc(newSViv(PTR2IV(result))),
        gv_stashpv("Lugh::Tensor", GV_ADD)
    );
OUTPUT:
    RETVAL

SV *
soft_max(ctx_sv, a_sv)
    SV *ctx_sv
    SV *a_sv
PREINIT:
    LughContext *lctx;
    struct ggml_tensor *a, *result;
CODE:
    /*
     * Build a ggml_soft_max node for tensor a in the given context and
     * return it as a new Lugh::Tensor handle. Croaks when a is not a
     * tensor handle or when ggml returns no result.
     */
    lctx = get_lugh_context(aTHX_ ctx_sv);

    if (!SvROK(a_sv) || !SvIOK(SvRV(a_sv))) {
        croak("soft_max: operand must be a Lugh::Tensor handle");
    }
    a = INT2PTR(struct ggml_tensor *, SvIV(SvRV(a_sv)));
    if (!a) {
        croak("soft_max: tensor handle is NULL");
    }

    result = ggml_soft_max(lctx->ctx, a);
    if (!result) {
        croak("ggml_soft_max failed");
    }

    RETVAL = sv_bless(
        newRV_noinc(newSViv(PTR2IV(result))),
        gv_stashpv("Lugh::Tensor", GV_ADD)
    );
OUTPUT:
    RETVAL

SV *
rms_norm(ctx_sv, a_sv, eps)
    SV *ctx_sv
    SV *a_sv
    float eps
PREINIT:
    LughContext *lctx;
    struct ggml_tensor *a, *result;
CODE:
    /*
     * Build a ggml_rms_norm node for tensor a with the given eps in the
     * given context and return it as a new Lugh::Tensor handle. Croaks
     * when a is not a tensor handle or when ggml returns no result.
     */
    lctx = get_lugh_context(aTHX_ ctx_sv);

    if (!SvROK(a_sv) || !SvIOK(SvRV(a_sv))) {
        croak("rms_norm: operand must be a Lugh::Tensor handle");
    }
    a = INT2PTR(struct ggml_tensor *, SvIV(SvRV(a_sv)));
    if (!a) {
        croak("rms_norm: tensor handle is NULL");
    }

    result = ggml_rms_norm(lctx->ctx, a, eps);
    if (!result) {
        croak("ggml_rms_norm failed");
    }

    RETVAL = sv_bless(
        newRV_noinc(newSViv(PTR2IV(result))),
        gv_stashpv("Lugh::Tensor", GV_ADD)
    );
OUTPUT:
    RETVAL

SV *
silu(ctx_sv, a_sv)
    SV *ctx_sv
    SV *a_sv
PREINIT:
    LughContext *lctx;
    struct ggml_tensor *a, *result;
CODE:
    /*
     * Build a ggml_silu node for tensor a in the given context and return
     * it as a new Lugh::Tensor handle. Croaks when a is not a tensor
     * handle or when ggml returns no result.
     */
    lctx = get_lugh_context(aTHX_ ctx_sv);

    if (!SvROK(a_sv) || !SvIOK(SvRV(a_sv))) {
        croak("silu: operand must be a Lugh::Tensor handle");
    }
    a = INT2PTR(struct ggml_tensor *, SvIV(SvRV(a_sv)));
    if (!a) {
        croak("silu: tensor handle is NULL");
    }

    result = ggml_silu(lctx->ctx, a);
    if (!result) {
        croak("ggml_silu failed");
    }

    RETVAL = sv_bless(
        newRV_noinc(newSViv(PTR2IV(result))),
        gv_stashpv("Lugh::Tensor", GV_ADD)
    );
OUTPUT:
    RETVAL

MODULE = Lugh    PACKAGE = Lugh::Graph

SV *
new(class, ctx_sv)
    char *class
    SV *ctx_sv
PREINIT:
    LughContext *owner;
    struct ggml_cgraph *cgraph;
    SV *handle;
CODE:
    /*
     * Allocate a computation graph with ggml_new_graph in the context
     * behind ctx_sv. The result is a reference, blessed into `class`,
     * to an IV that stores the raw ggml_cgraph pointer.
     */
    owner = get_lugh_context(aTHX_ ctx_sv);
    cgraph = ggml_new_graph(owner->ctx);
    if (cgraph == NULL) {
        croak("Failed to create computation graph");
    }

    handle = newSViv(PTR2IV(cgraph));
    RETVAL = sv_bless(newRV_noinc(handle), gv_stashpv(class, GV_ADD));
OUTPUT:
    RETVAL

void
build_forward(self, tensor_sv)
    SV *self
    SV *tensor_sv
PREINIT:
    struct ggml_cgraph *graph;
    struct ggml_tensor *tensor;
CODE:
    /*
     * Expand the graph behind self with the forward computation of the
     * given tensor (ggml_build_forward_expand). Croaks when either
     * argument is not a valid handle.
     */
    if (!SvROK(self) || !SvIOK(SvRV(self))) {
        croak("build_forward: self is not a Lugh::Graph handle");
    }
    if (!SvROK(tensor_sv) || !SvIOK(SvRV(tensor_sv))) {
        croak("build_forward: argument is not a Lugh::Tensor handle");
    }
    graph = INT2PTR(struct ggml_cgraph *, SvIV(SvRV(self)));
    tensor = INT2PTR(struct ggml_tensor *, SvIV(SvRV(tensor_sv)));
    if (!graph || !tensor) {
        croak("build_forward: handle is NULL");
    }

    ggml_build_forward_expand(graph, tensor);

void
compute(self, ctx_sv, n_threads)
    SV *self
    SV *ctx_sv
    int n_threads
PREINIT:
    LughContext *lctx;
    struct ggml_cgraph *graph;
    enum ggml_status status;
CODE:
    /*
     * Run the graph behind self with ggml_graph_compute_with_ctx, using
     * the context behind ctx_sv and the requested number of threads.
     * Croaks when self is not a graph handle or when ggml reports a
     * status other than GGML_STATUS_SUCCESS.
     */
    lctx = get_lugh_context(aTHX_ ctx_sv);

    if (!SvROK(self) || !SvIOK(SvRV(self))) {
        croak("compute: self is not a Lugh::Graph handle");
    }
    graph = INT2PTR(struct ggml_cgraph *, SvIV(SvRV(self)));
    if (!graph) {
        croak("compute: graph handle is NULL");
    }

    /* The status used to be discarded, hiding failed or aborted runs. */
    status = ggml_graph_compute_with_ctx(lctx->ctx, graph, n_threads);
    if (status != GGML_STATUS_SUCCESS) {
        croak("Graph computation failed: %s", ggml_status_to_string(status));
    }

MODULE = Lugh    PACKAGE = Lugh::Model

SV *
new(class, ...)
    char *class
PREINIT:
    LughModel *lm;
    const char *filename = NULL;
    struct gguf_init_params gguf_params;
    struct ggml_context *tensor_ctx = NULL;
    SV *sv;
    int i, id;
    int64_t key_id;
CODE:
    /*
     * Load a GGUF file and return a model handle blessed into `class`.
     *
     * Arguments are key/value pairs; the keys "model", "file" and "path"
     * all name the GGUF file to load. Croaks when no file is given, when
     * no model id can be allocated, or when the file cannot be loaded.
     *
     * The handle is a reference to an SV carrying ext magic whose pointer
     * slot stores the model's registry id.
     */
    INIT_MUTEXES();

    /* Parse arguments; a trailing key without a value is ignored. */
    for (i = 1; i + 1 < items; i += 2) {
        const char *key = SvPV_nolen(ST(i));
        if (strEQ(key, "model") || strEQ(key, "file") || strEQ(key, "path")) {
            filename = SvPV_nolen(ST(i + 1));
        }
    }

    if (!filename) {
        croak("Lugh::Model->new requires 'model' parameter with path to GGUF file");
    }

    /* Allocate model ID */
    id = alloc_model_id();
    if (id < 0) {
        croak("Maximum number of models (%d) reached", MAX_CONTEXTS);
    }

    /* Allocate model structure (zero-filled, so architecture starts NULL) */
    Newxz(lm, 1, LughModel);
    lm->id = id;
    lm->active = 1;

    /* Keep a private copy of the filename; the SV buffer may go away. */
    lm->filename = savepv(filename);

    /* Initialize GGUF context; tensor data is loaded into tensor_ctx. */
    gguf_params.no_alloc = false;
    gguf_params.ctx = &tensor_ctx;

    lm->gguf = gguf_init_from_file(filename, gguf_params);
    if (!lm->gguf) {
        /* NOTE(review): the id from alloc_model_id() is not handed back
         * here; confirm whether it needs an explicit release. */
        Safefree(lm->filename);
        Safefree(lm);
        croak("Failed to load GGUF file: %s", filename);
    }

    lm->ctx = tensor_ctx;
    lm->n_tensors = gguf_get_n_tensors(lm->gguf);
    lm->n_kv = gguf_get_n_kv(lm->gguf);

    /* Get architecture if available. Only read it when the value really
     * is a string, since gguf_get_val_str must not be used on other types. */
    key_id = gguf_find_key(lm->gguf, "general.architecture");
    if (key_id >= 0 && gguf_get_kv_type(lm->gguf, key_id) == GGUF_TYPE_STRING) {
        const char *arch = gguf_get_val_str(lm->gguf, key_id);
        if (arch) {
            lm->architecture = savepv(arch);
        }
    }

    /* Register in global registry */
    CONTEXT_LOCK();
    model_registry[id] = lm;
    CONTEXT_UNLOCK();

    /* Create blessed reference with magic */
    sv = newSV(0);
    sv_magicext(sv, NULL, PERL_MAGIC_ext, &lugh_model_vtbl, INT2PTR(char*, (IV)id), 0);
    RETVAL = sv_bless(newRV_noinc(sv), gv_stashpv(class, GV_ADD));
OUTPUT:
    RETVAL

const char *
filename(self)
    SV *self
PREINIT:
    LughModel *model;
CODE:
    /* Return the filename string stored in the model structure. */
    model = get_lugh_model(aTHX_ self);
    RETVAL = model->filename;
OUTPUT:
    RETVAL

const char *
architecture(self)
    SV *self
PREINIT:
    LughModel *model;
CODE:
    /*
     * Return the architecture string stored in the model structure,
     * or the literal "unknown" when none is stored.
     */
    model = get_lugh_model(aTHX_ self);
    if (model->architecture) {
        RETVAL = model->architecture;
    } else {
        RETVAL = "unknown";
    }
OUTPUT:
    RETVAL

int64_t
n_tensors(self)
    SV *self
PREINIT:
    LughModel *model;
CODE:
    /* Return the tensor count cached in the model structure. */
    model = get_lugh_model(aTHX_ self);
    RETVAL = model->n_tensors;
OUTPUT:
    RETVAL

int64_t
n_kv(self)
    SV *self
PREINIT:
    LughModel *model;
CODE:
    /* Return the key/value pair count cached in the model structure. */
    model = get_lugh_model(aTHX_ self);
    RETVAL = model->n_kv;
OUTPUT:
    RETVAL

void
tensor_info(self, name)
    SV *self
    const char *name
PREINIT:
    LughModel *model;
    struct ggml_tensor *found;
    int d;
PPCODE:
    /*
     * Look up a tensor by name in the model's tensor context.
     * Returns the list (type, n_dims, ne[0], ne[1], ne[2], ne[3]) when
     * the tensor exists, and an empty list otherwise.
     */
    model = get_lugh_model(aTHX_ self);
    found = ggml_get_tensor(model->ctx, name);
    if (found != NULL) {
        EXTEND(SP, 6);
        mPUSHi(found->type);
        mPUSHi(ggml_n_dims(found));
        /* All four extents are reported, regardless of n_dims. */
        for (d = 0; d < 4; d++) {
            mPUSHi(found->ne[d]);
        }
    }

void
tensor_names(self)
    SV *self
PREINIT:
    LughModel *model;
    int64_t idx;
PPCODE:
    /*
     * Return the names of all tensors in the GGUF file as a list,
     * in index order from 0 to n_tensors - 1.
     */
    model = get_lugh_model(aTHX_ self);
    EXTEND(SP, model->n_tensors);
    idx = 0;
    while (idx < model->n_tensors) {
        const char *tname = gguf_get_tensor_name(model->gguf, idx);
        PUSHs(sv_2mortal(newSVpv(tname, 0)));
        idx++;
    }

void
kv_keys(self)
    SV *self
PREINIT:
    LughModel *model;
    int64_t idx;
PPCODE:
    /*
     * Return all metadata keys of the GGUF file as a list,
     * in index order from 0 to n_kv - 1.
     */
    model = get_lugh_model(aTHX_ self);
    EXTEND(SP, model->n_kv);
    idx = 0;
    while (idx < model->n_kv) {
        const char *kname = gguf_get_key(model->gguf, idx);
        PUSHs(sv_2mortal(newSVpv(kname, 0)));
        idx++;
    }

SV *
get_kv(self, key)
    SV *self
    const char *key
PREINIT:
    LughModel *lm;
    int64_t key_id;
    enum gguf_type kv_type;
CODE:
    /*
     * Look up a metadata key in the GGUF file and return its value.
     *
     * Returns undef when the key does not exist or has an unhandled type.
     * Scalar types map to Perl integers, numbers, booleans or strings.
     * Arrays are returned as an array reference: string arrays as strings,
     * numeric and boolean arrays decoded element by element. An array
     * whose element type is not recognised yields an empty array.
     */
    lm = get_lugh_model(aTHX_ self);
    key_id = gguf_find_key(lm->gguf, key);

    if (key_id < 0) {
        RETVAL = &PL_sv_undef;
    } else {
        kv_type = gguf_get_kv_type(lm->gguf, key_id);
        switch (kv_type) {
            case GGUF_TYPE_UINT8:
                RETVAL = newSVuv(gguf_get_val_u8(lm->gguf, key_id));
                break;
            case GGUF_TYPE_INT8:
                RETVAL = newSViv(gguf_get_val_i8(lm->gguf, key_id));
                break;
            case GGUF_TYPE_UINT16:
                RETVAL = newSVuv(gguf_get_val_u16(lm->gguf, key_id));
                break;
            case GGUF_TYPE_INT16:
                RETVAL = newSViv(gguf_get_val_i16(lm->gguf, key_id));
                break;
            case GGUF_TYPE_UINT32:
                RETVAL = newSVuv(gguf_get_val_u32(lm->gguf, key_id));
                break;
            case GGUF_TYPE_INT32:
                RETVAL = newSViv(gguf_get_val_i32(lm->gguf, key_id));
                break;
            case GGUF_TYPE_UINT64:
                RETVAL = newSVuv(gguf_get_val_u64(lm->gguf, key_id));
                break;
            case GGUF_TYPE_INT64:
                RETVAL = newSViv(gguf_get_val_i64(lm->gguf, key_id));
                break;
            case GGUF_TYPE_FLOAT32:
                RETVAL = newSVnv(gguf_get_val_f32(lm->gguf, key_id));
                break;
            case GGUF_TYPE_FLOAT64:
                RETVAL = newSVnv(gguf_get_val_f64(lm->gguf, key_id));
                break;
            case GGUF_TYPE_BOOL:
                RETVAL = gguf_get_val_bool(lm->gguf, key_id) ? &PL_sv_yes : &PL_sv_no;
                break;
            case GGUF_TYPE_STRING:
                RETVAL = newSVpv(gguf_get_val_str(lm->gguf, key_id), 0);
                break;
            case GGUF_TYPE_ARRAY:
                /* Return array reference */
                {
                    enum gguf_type arr_type = gguf_get_arr_type(lm->gguf, key_id);
                    size_t n = gguf_get_arr_n(lm->gguf, key_id);
                    AV *av = newAV();
                    size_t j;

                    if (arr_type == GGUF_TYPE_STRING) {
                        for (j = 0; j < n; j++) {
                            av_push(av, newSVpv(gguf_get_arr_str(lm->gguf, key_id, j), 0));
                        }
                    } else if (n > 0) {
                        /* Non-string arrays are read from the raw element
                         * buffer and decoded according to arr_type. */
                        const void *data = gguf_get_arr_data(lm->gguf, key_id);

                        for (j = 0; data && j < n; j++) {
                            SV *elem = NULL;

                            switch (arr_type) {
                                case GGUF_TYPE_UINT8:
                                    elem = newSVuv(((const uint8_t *)data)[j]);
                                    break;
                                case GGUF_TYPE_INT8:
                                    elem = newSViv(((const int8_t *)data)[j]);
                                    break;
                                case GGUF_TYPE_UINT16:
                                    elem = newSVuv(((const uint16_t *)data)[j]);
                                    break;
                                case GGUF_TYPE_INT16:
                                    elem = newSViv(((const int16_t *)data)[j]);
                                    break;
                                case GGUF_TYPE_UINT32:
                                    elem = newSVuv(((const uint32_t *)data)[j]);
                                    break;
                                case GGUF_TYPE_INT32:
                                    elem = newSViv(((const int32_t *)data)[j]);
                                    break;
                                case GGUF_TYPE_UINT64:
                                    elem = newSVuv(((const uint64_t *)data)[j]);
                                    break;
                                case GGUF_TYPE_INT64:
                                    elem = newSViv(((const int64_t *)data)[j]);
                                    break;
                                case GGUF_TYPE_FLOAT32:
                                    elem = newSVnv(((const float *)data)[j]);
                                    break;
                                case GGUF_TYPE_FLOAT64:
                                    elem = newSVnv(((const double *)data)[j]);
                                    break;
                                case GGUF_TYPE_BOOL:
                                    /* assumes one byte per bool element, as in
                                     * the GGUF format — TODO confirm */
                                    elem = newSVsv(boolSV(((const int8_t *)data)[j] != 0));
                                    break;
                                default:
                                    /* Unknown element type: leave the array empty. */
                                    break;
                            }

                            if (!elem) {
                                break;
                            }
                            av_push(av, elem);
                        }
                    }
                    RETVAL = newRV_noinc((SV*)av);
                }
                break;
            default:
                RETVAL = &PL_sv_undef;
                break;
        }
    }
OUTPUT:
    RETVAL

void
DESTROY(self)
    SV *self
CODE:
    /*
     * Intentionally empty destructor: this method releases nothing itself.
     * Per the original note, cleanup is left to the magic attached to the
     * model handle (NOTE(review): presumably the free hook of
     * lugh_model_vtbl; its definition is outside this section).
     */
    PERL_UNUSED_VAR(self);
